diff --git a/client/client.go b/client/client.go index b874b3b522..7475d3f968 100644 --- a/client/client.go +++ b/client/client.go @@ -173,10 +173,17 @@ func WithTLSClientConfig(cacertPath, certPath, keyPath string) func(*Client) err // WithDialer applies the dialer.DialContext to the client transport. This can be // used to set the Timeout and KeepAlive settings of the client. +// Deprecated: use WithDialContext func WithDialer(dialer *net.Dialer) func(*Client) error { + return WithDialContext(dialer.DialContext) +} + +// WithDialContext applies the dialer to the client transport. This can be +// used to set the Timeout and KeepAlive settings of the client. +func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) func(*Client) error { return func(c *Client) error { if transport, ok := c.client.Transport.(*http.Transport); ok { - transport.DialContext = dialer.DialContext + transport.DialContext = dialContext return nil } return errors.Errorf("cannot apply dialer to transport: %T", c.client.Transport) @@ -400,3 +407,16 @@ func (cli *Client) CustomHTTPHeaders() map[string]string { func (cli *Client) SetCustomHTTPHeaders(headers map[string]string) { cli.customHTTPHeaders = headers } + +// Dialer returns a dialer for a raw stream connection, with HTTP/1.1 header, that can be used for proxying the daemon connection. +// Used by `docker dial-stdio` (docker/cli#889). +func (cli *Client) Dialer() func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + if transport, ok := cli.client.Transport.(*http.Transport); ok { + if transport.DialContext != nil { + return transport.DialContext(ctx, cli.proto, cli.addr) + } + } + return fallbackDial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport)) + } +} diff --git a/client/hijack.go b/client/hijack.go index 35f5dd86dc..0ac8248f2c 100644 --- a/client/hijack.go +++ b/client/hijack.go @@ -30,7 +30,7 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu } req = cli.addHeaders(req, headers) - conn, err := cli.setupHijackConn(req, "tcp") + conn, err := cli.setupHijackConn(ctx, req, "tcp") if err != nil { return types.HijackedResponse{}, err } @@ -38,7 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err } -func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) { +// fallbackDial is used when WithDialer() was not called. +// See cli.Dialer(). +func fallbackDial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) { if tlsConfig != nil && proto != "unix" && proto != "npipe" { return tls.Dial(proto, addr, tlsConfig) } @@ -48,12 +50,13 @@ func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) { return net.Dial(proto, addr) } -func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, error) { +func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, error) { req.Host = cli.addr req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", proto) - conn, err := dial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport)) + dialer := cli.Dialer() + conn, err := dialer(ctx) if err != nil { return nil, errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?") } diff --git a/client/interface.go b/client/interface.go index 9250c468a6..663749f008 100644 --- a/client/interface.go +++ b/client/interface.go @@ -39,6 +39,7 @@ type CommonAPIClient interface { NegotiateAPIVersion(ctx context.Context) NegotiateAPIVersionPing(types.Ping) DialSession(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) + Dialer() func(context.Context) (net.Conn, error) Close() error } diff --git a/client/session.go b/client/session.go index c247123b45..df199f3d03 100644 --- a/client/session.go +++ b/client/session.go @@ -14,5 +14,5 @@ func (cli *Client) DialSession(ctx context.Context, proto string, meta map[strin } req = cli.addHeaders(req, meta) - return cli.setupHijackConn(req, proto) + return cli.setupHijackConn(ctx, req, proto) }