Browse Source

Use stdlib TLS dialer

Since go1.8, the stdlib TLS net.Conn implementation implements the
`CloseWrite()` interface.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Brian Goff 7 years ago
parent
commit
2ac277a56f
4 changed files with 104 additions and 148 deletions
  1. 1 104
      client/hijack.go
  2. 103 0
      client/hijack_test.go
  3. 0 11
      client/tlsconfig_clone.go
  4. 0 33
      client/tlsconfig_clone_go17.go

+ 1 - 104
client/hijack.go

@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"net/http/httputil"
 	"net/url"
-	"strings"
 	"time"
 
 	"github.com/docker/docker/api/types"
@@ -17,21 +16,6 @@ import (
 	"github.com/pkg/errors"
 )
 
-// tlsClientCon holds tls information and a dialed connection.
-type tlsClientCon struct {
-	*tls.Conn
-	rawConn net.Conn
-}
-
-func (c *tlsClientCon) CloseWrite() error {
-	// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
-	// on its underlying connection.
-	if conn, ok := c.rawConn.(types.CloseWriter); ok {
-		return conn.CloseWrite()
-	}
-	return nil
-}
-
 // postHijacked sends a POST request and hijacks the connection.
 func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) {
 	bodyEncoded, err := encodeData(body)
@@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
 	return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
 }
 
-func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
-	return tlsDialWithDialer(new(net.Dialer), network, addr, config)
-}
-
-// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
-// order to return our custom tlsClientCon struct which holds both the tls.Conn
-// object _and_ its underlying raw connection. The rationale for this is that
-// we need to be able to close the write end of the connection when attaching,
-// which tls.Conn does not provide.
-func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
-	// We want the Timeout and Deadline values from dialer to cover the
-	// whole process: TCP connection and TLS handshake. This means that we
-	// also need to start our own timers now.
-	timeout := dialer.Timeout
-
-	if !dialer.Deadline.IsZero() {
-		deadlineTimeout := time.Until(dialer.Deadline)
-		if timeout == 0 || deadlineTimeout < timeout {
-			timeout = deadlineTimeout
-		}
-	}
-
-	var errChannel chan error
-
-	if timeout != 0 {
-		errChannel = make(chan error, 2)
-		time.AfterFunc(timeout, func() {
-			errChannel <- errors.New("")
-		})
-	}
-
-	proxyDialer, err := sockets.DialerFromEnvironment(dialer)
-	if err != nil {
-		return nil, err
-	}
-
-	rawConn, err := proxyDialer.Dial(network, addr)
-	if err != nil {
-		return nil, err
-	}
-	// When we set up a TCP connection for hijack, there could be long periods
-	// of inactivity (a long running command with no output) that in certain
-	// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
-	// state. Setting TCP KeepAlive on the socket connection will prohibit
-	// ECONNTIMEOUT unless the socket connection truly is broken
-	if tcpConn, ok := rawConn.(*net.TCPConn); ok {
-		tcpConn.SetKeepAlive(true)
-		tcpConn.SetKeepAlivePeriod(30 * time.Second)
-	}
-
-	colonPos := strings.LastIndex(addr, ":")
-	if colonPos == -1 {
-		colonPos = len(addr)
-	}
-	hostname := addr[:colonPos]
-
-	// If no ServerName is set, infer the ServerName
-	// from the hostname we're connecting to.
-	if config.ServerName == "" {
-		// Make a copy to avoid polluting argument or default.
-		config = tlsConfigClone(config)
-		config.ServerName = hostname
-	}
-
-	conn := tls.Client(rawConn, config)
-
-	if timeout == 0 {
-		err = conn.Handshake()
-	} else {
-		go func() {
-			errChannel <- conn.Handshake()
-		}()
-
-		err = <-errChannel
-	}
-
-	if err != nil {
-		rawConn.Close()
-		return nil, err
-	}
-
-	// This is Docker difference with standard's crypto/tls package: returned a
-	// wrapper which holds both the TLS and raw connections.
-	return &tlsClientCon{conn, rawConn}, nil
-}
-
 func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
 	if tlsConfig != nil && proto != "unix" && proto != "npipe" {
-		// Notice this isn't Go standard's tls.Dial function
-		return tlsDial(proto, addr, tlsConfig)
+		return tls.Dial(proto, addr, tlsConfig)
 	}
 	if proto == "npipe" {
 		return sockets.DialPipe(addr, 32*time.Second)

+ 103 - 0
client/hijack_test.go

@@ -0,0 +1,103 @@
+package client
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+
+	"github.com/docker/docker/api/server/httputils"
+	"github.com/docker/docker/api/types"
+	"github.com/gotestyourself/gotestyourself/assert"
+	"github.com/pkg/errors"
+	"golang.org/x/net/context"
+)
+
+func TestTLSCloseWriter(t *testing.T) {
+	t.Parallel()
+
+	var chErr chan error
+	ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+		chErr = make(chan error, 1)
+		defer close(chErr)
+		if err := httputils.ParseForm(req); err != nil {
+			chErr <- errors.Wrap(err, "error parsing form")
+			http.Error(w, err.Error(), 500)
+			return
+		}
+		r, rw, err := httputils.HijackConnection(w)
+		if err != nil {
+			chErr <- errors.Wrap(err, "error hijacking connection")
+			http.Error(w, err.Error(), 500)
+			return
+		}
+		defer r.Close()
+
+		fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
+
+		buf := make([]byte, 5)
+		_, err = r.Read(buf)
+		if err != nil {
+			chErr <- errors.Wrap(err, "error reading from client")
+			return
+		}
+		_, err = rw.Write(buf)
+		if err != nil {
+			chErr <- errors.Wrap(err, "error writing to client")
+			return
+		}
+	})}}
+
+	var (
+		l   net.Listener
+		err error
+	)
+	for i := 1024; i < 10000; i++ {
+		l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
+		if err == nil {
+			break
+		}
+	}
+	assert.Assert(t, err)
+
+	ts.Listener = l
+	defer l.Close()
+
+	defer func() {
+		if chErr != nil {
+			assert.Assert(t, <-chErr)
+		}
+	}()
+
+	ts.StartTLS()
+	defer ts.Close()
+
+	serverURL, err := url.Parse(ts.URL)
+	assert.Assert(t, err)
+
+	client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil)
+	assert.Assert(t, err)
+
+	resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
+	assert.Assert(t, err)
+	defer resp.Close()
+
+	if _, ok := resp.Conn.(types.CloseWriter); !ok {
+		t.Fatal("tls conn did not implement the CloseWrite interface")
+	}
+
+	_, err = resp.Conn.Write([]byte("hello"))
+	assert.Assert(t, err)
+
+	b, err := ioutil.ReadAll(resp.Reader)
+	assert.Assert(t, err)
+	assert.Assert(t, string(b) == "hello")
+	assert.Assert(t, resp.CloseWrite())
+
+	// This should error since writes are closed
+	_, err = resp.Conn.Write([]byte("no"))
+	assert.Assert(t, err != nil)
+}

+ 0 - 11
client/tlsconfig_clone.go

@@ -1,11 +0,0 @@
-// +build go1.8
-
-package client // import "github.com/docker/docker/client"
-
-import "crypto/tls"
-
-// tlsConfigClone returns a clone of tls.Config. This function is provided for
-// compatibility for go1.7 that doesn't include this method in stdlib.
-func tlsConfigClone(c *tls.Config) *tls.Config {
-	return c.Clone()
-}

+ 0 - 33
client/tlsconfig_clone_go17.go

@@ -1,33 +0,0 @@
-// +build go1.7,!go1.8
-
-package client // import "github.com/docker/docker/client"
-
-import "crypto/tls"
-
-// tlsConfigClone returns a clone of tls.Config. This function is provided for
-// compatibility for go1.7 that doesn't include this method in stdlib.
-func tlsConfigClone(c *tls.Config) *tls.Config {
-	return &tls.Config{
-		Rand:                        c.Rand,
-		Time:                        c.Time,
-		Certificates:                c.Certificates,
-		NameToCertificate:           c.NameToCertificate,
-		GetCertificate:              c.GetCertificate,
-		RootCAs:                     c.RootCAs,
-		NextProtos:                  c.NextProtos,
-		ServerName:                  c.ServerName,
-		ClientAuth:                  c.ClientAuth,
-		ClientCAs:                   c.ClientCAs,
-		InsecureSkipVerify:          c.InsecureSkipVerify,
-		CipherSuites:                c.CipherSuites,
-		PreferServerCipherSuites:    c.PreferServerCipherSuites,
-		SessionTicketsDisabled:      c.SessionTicketsDisabled,
-		SessionTicketKey:            c.SessionTicketKey,
-		ClientSessionCache:          c.ClientSessionCache,
-		MinVersion:                  c.MinVersion,
-		MaxVersion:                  c.MaxVersion,
-		CurvePreferences:            c.CurvePreferences,
-		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
-		Renegotiation:               c.Renegotiation,
-	}
-}