Sfoglia il codice sorgente

client: add WithDialContext() and client.Dialer()

WithDialContext() allows specifying custom dialer for hijacking and supposed to
replace WithDialer().
WithDialer() is also updated to use WithDialContext().

client.Dialer() returns the dialer configured with WithDialContext().

Signed-off-by: Akihiro Suda <suda.akihiro@lab.ntt.co.jp>
Akihiro Suda 7 anni fa
parent
commit
edac92409a
4 ha cambiato i file con 30 aggiunte e 6 eliminazioni
  1. 21 1
      client/client.go
  2. 7 4
      client/hijack.go
  3. 1 0
      client/interface.go
  4. 1 1
      client/session.go

+ 21 - 1
client/client.go

@@ -173,10 +173,17 @@ func WithTLSClientConfig(cacertPath, certPath, keyPath string) func(*Client) err
 
 // WithDialer applies the dialer.DialContext to the client transport. This can be
 // used to set the Timeout and KeepAlive settings of the client.
+// Deprecated: use WithDialContext
 func WithDialer(dialer *net.Dialer) func(*Client) error {
+	return WithDialContext(dialer.DialContext)
+}
+
+// WithDialContext applies the dialer to the client transport. This can be
+// used to set the Timeout and KeepAlive settings of the client.
+func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) func(*Client) error {
 	return func(c *Client) error {
 		if transport, ok := c.client.Transport.(*http.Transport); ok {
-			transport.DialContext = dialer.DialContext
+			transport.DialContext = dialContext
 			return nil
 		}
 		return errors.Errorf("cannot apply dialer to transport: %T", c.client.Transport)
@@ -400,3 +407,16 @@ func (cli *Client) CustomHTTPHeaders() map[string]string {
 func (cli *Client) SetCustomHTTPHeaders(headers map[string]string) {
 	cli.customHTTPHeaders = headers
 }
+
+// Dialer returns a dialer for a raw stream connection, with HTTP/1.1 header, that can be used for proxying the daemon connection.
+// Used by `docker dial-stdio` (docker/cli#889).
+func (cli *Client) Dialer() func(context.Context) (net.Conn, error) {
+	return func(ctx context.Context) (net.Conn, error) {
+		if transport, ok := cli.client.Transport.(*http.Transport); ok {
+			if transport.DialContext != nil {
+				return transport.DialContext(ctx, cli.proto, cli.addr)
+			}
+		}
+		return fallbackDial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport))
+	}
+}

+ 7 - 4
client/hijack.go

@@ -30,7 +30,7 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
 	}
 	req = cli.addHeaders(req, headers)
 
-	conn, err := cli.setupHijackConn(req, "tcp")
+	conn, err := cli.setupHijackConn(ctx, req, "tcp")
 	if err != nil {
 		return types.HijackedResponse{}, err
 	}
@@ -38,7 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
 	return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
 }
 
-func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
+// fallbackDial is used when WithDialer() was not called.
+// See cli.Dialer().
+func fallbackDial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
 	if tlsConfig != nil && proto != "unix" && proto != "npipe" {
 		return tls.Dial(proto, addr, tlsConfig)
 	}
@@ -48,12 +50,13 @@ func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
 	return net.Dial(proto, addr)
 }
 
-func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, error) {
+func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, error) {
 	req.Host = cli.addr
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Upgrade", proto)
 
-	conn, err := dial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport))
+	dialer := cli.Dialer()
+	conn, err := dialer(ctx)
 	if err != nil {
 		return nil, errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
 	}

+ 1 - 0
client/interface.go

@@ -39,6 +39,7 @@ type CommonAPIClient interface {
 	NegotiateAPIVersion(ctx context.Context)
 	NegotiateAPIVersionPing(types.Ping)
 	DialSession(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error)
+	Dialer() func(context.Context) (net.Conn, error)
 	Close() error
 }
 

+ 1 - 1
client/session.go

@@ -14,5 +14,5 @@ func (cli *Client) DialSession(ctx context.Context, proto string, meta map[strin
 	}
 	req = cli.addHeaders(req, meta)
 
-	return cli.setupHijackConn(req, proto)
+	return cli.setupHijackConn(ctx, req, proto)
 }