Browse Source

add test for filesync path filtering and testutil helper

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
Tonis Tiigi 8 years ago
parent
commit
4141d8fe5d

+ 71 - 0
client/session/filesync/filesync_test.go

@@ -0,0 +1,71 @@
+package filesync
+
+import (
+	"context"
+	"io/ioutil"
+	"path/filepath"
+	"testing"
+
+	"github.com/docker/docker/client/session"
+	"github.com/docker/docker/client/session/testutil"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/errgroup"
+)
+
+func TestFileSyncIncludePatterns(t *testing.T) {
+	tmpDir, err := ioutil.TempDir("", "fsynctest")
+	require.NoError(t, err)
+
+	destDir, err := ioutil.TempDir("", "fsynctest")
+	require.NoError(t, err)
+
+	err = ioutil.WriteFile(filepath.Join(tmpDir, "foo"), []byte("content1"), 0600)
+	require.NoError(t, err)
+
+	err = ioutil.WriteFile(filepath.Join(tmpDir, "bar"), []byte("content2"), 0600)
+	require.NoError(t, err)
+
+	s, err := session.NewSession("foo", "bar")
+	require.NoError(t, err)
+
+	m, err := session.NewManager()
+	require.NoError(t, err)
+
+	fs := NewFSSyncProvider(tmpDir, nil)
+	s.Allow(fs)
+
+	dialer := session.Dialer(testutil.TestStream(testutil.Handler(m.HandleConn)))
+
+	g, ctx := errgroup.WithContext(context.Background())
+
+	g.Go(func() error {
+		return s.Run(ctx, dialer)
+	})
+
+	g.Go(func() (reterr error) {
+		c, err := m.Get(ctx, s.UUID())
+		if err != nil {
+			return err
+		}
+		if err := FSSync(ctx, c, FSSendRequestOpt{
+			DestDir:         destDir,
+			IncludePatterns: []string{"ba*"},
+		}); err != nil {
+			return err
+		}
+
+		_, err = ioutil.ReadFile(filepath.Join(destDir, "foo"))
+		assert.Error(t, err)
+
+		dt, err := ioutil.ReadFile(filepath.Join(destDir, "bar"))
+		if err != nil {
+			return err
+		}
+		assert.Equal(t, "content2", string(dt))
+		return s.Close()
+	})
+
+	err = g.Wait()
+	require.NoError(t, err)
+}

+ 18 - 3
client/session/manager.go

@@ -1,6 +1,7 @@
 package session
 
 import (
+	"net"
 	"net/http"
 	"strings"
 	"sync"
@@ -49,8 +50,6 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
 	}
 
 	uuid := r.Header.Get(headerSessionUUID)
-	name := r.Header.Get(headerSessionName)
-	sharedKey := r.Header.Get(headerSessionSharedKey)
 
 	proto := r.Header.Get("Upgrade")
 
@@ -89,9 +88,25 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
 	conn.Write([]byte{})
 	resp.Write(conn)
 
+	return sm.handleConn(ctx, conn, r.Header)
+}
+
+// HandleConn handles an incoming raw connection
+func (sm *Manager) HandleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
+	sm.mu.Lock()
+	return sm.handleConn(ctx, conn, opts)
+}
+
+// caller needs to take lock, this function will release it
+func (sm *Manager) handleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
+	h := http.Header(opts)
+	uuid := h.Get(headerSessionUUID)
+	name := h.Get(headerSessionName)
+	sharedKey := h.Get(headerSessionSharedKey)
+
 	ctx, cc, err := grpcClientConn(ctx, conn)
 	if err != nil {
 		sm.mu.Unlock()
@@ -111,7 +126,7 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
 		supported: make(map[string]struct{}),
 	}
 
-	for _, m := range r.Header[headerSessionMethod] {
+	for _, m := range opts[headerSessionMethod] {
 		c.supported[strings.ToLower(m)] = struct{}{}
 	}
 	sm.sessions[uuid] = c

+ 70 - 0
client/session/testutil/testutil.go

@@ -0,0 +1,70 @@
+package testutil
+
+import (
+	"io"
+	"net"
+	"time"
+
+	"github.com/Sirupsen/logrus"
+	"golang.org/x/net/context"
+)
+
+// Handler is function called to handle incoming connection
+type Handler func(ctx context.Context, conn net.Conn, meta map[string][]string) error
+
+// Dialer is a function for dialing an outgoing connection
+type Dialer func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error)
+
+// TestStream creates an in memory session dialer for a handler function
+func TestStream(handler Handler) Dialer {
+	s1, s2 := sockPair()
+	return func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) {
+		go func() {
+			err := handler(context.TODO(), s1, meta)
+			if err != nil {
+				logrus.Error(err)
+			}
+			s1.Close()
+		}()
+		return s2, nil
+	}
+}
+
+func sockPair() (*sock, *sock) {
+	pr1, pw1 := io.Pipe()
+	pr2, pw2 := io.Pipe()
+	return &sock{pw1, pr2, pw1}, &sock{pw2, pr1, pw2}
+}
+
+type sock struct {
+	io.Writer
+	io.Reader
+	io.Closer
+}
+
+func (s *sock) LocalAddr() net.Addr {
+	return dummyAddr{}
+}
+func (s *sock) RemoteAddr() net.Addr {
+	return dummyAddr{}
+}
+func (s *sock) SetDeadline(t time.Time) error {
+	return nil
+}
+func (s *sock) SetReadDeadline(t time.Time) error {
+	return nil
+}
+func (s *sock) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+type dummyAddr struct {
+}
+
+func (d dummyAddr) Network() string {
+	return "tcp"
+}
+
+func (d dummyAddr) String() string {
+	return "localhost"
+}