Kaynağa Gözat

Add transport package to support CancelRequest

Signed-off-by: Tibor Vass <tibor@docker.com>
Tibor Vass 10 yıl önce
ebeveyn
işleme
73823e5e56

+ 7 - 5
graph/pull.go

@@ -17,6 +17,7 @@ import (
 	"github.com/docker/docker/pkg/progressreader"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
+	"github.com/docker/docker/pkg/transport"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/utils"
 )
@@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf
 	defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag))
 
 	logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName)
-	endpoint, err := repoInfo.GetEndpoint()
+
+	endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders)
 	if err != nil {
 		return err
 	}
-
+	// TODO(tiborvass): reuse client from endpoint?
 	// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
-	tr := &registry.DockerHeaders{
+	tr := transport.NewTransport(
 		registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure),
-		imagePullConfig.MetaHeaders,
-	}
+		registry.DockerHeaders(imagePullConfig.MetaHeaders)...,
+	)
 	client := registry.HTTPClient(tr)
 	r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint)
 	if err != nil {

+ 7 - 5
graph/push.go

@@ -18,6 +18,7 @@ import (
 	"github.com/docker/docker/pkg/progressreader"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
+	"github.com/docker/docker/pkg/transport"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/runconfig"
 	"github.com/docker/docker/utils"
@@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro
 	}
 	defer s.poolRemove("push", repoInfo.LocalName)
 
-	endpoint, err := repoInfo.GetEndpoint()
+	endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders)
 	if err != nil {
 		return err
 	}
-
+	// TODO(tiborvass): reuse client from endpoint?
 	// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
-	tr := &registry.DockerHeaders{
+	tr := transport.NewTransport(
 		registry.NewTransport(registry.NoTimeout, endpoint.IsSecure),
-		imagePushConfig.MetaHeaders,
-	}
+		registry.DockerHeaders(imagePushConfig.MetaHeaders)...,
+	)
+	client := registry.HTTPClient(tr)
 	r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint)
 	if err != nil {
 		return err

+ 27 - 0
pkg/transport/LICENSE

@@ -0,0 +1,27 @@
+Copyright (c) 2009 The oauth2 Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+   * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+   * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+   * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 148 - 0
pkg/transport/transport.go

@@ -0,0 +1,148 @@
+package transport
+
+import (
+	"io"
+	"net/http"
+	"sync"
+)
+
+type RequestModifier interface {
+	ModifyRequest(*http.Request) error
+}
+
+type headerModifier http.Header
+
+// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers
+// passed as an argument, with the HTTP headers of a request.
+//
+// If the same key is present in both, the modifying header values for that key,
+// are appended to the values for that same key in the request header.
+func NewHeaderRequestModifier(header http.Header) RequestModifier {
+	return headerModifier(header)
+}
+
+func (h headerModifier) ModifyRequest(req *http.Request) error {
+	for k, s := range http.Header(h) {
+		req.Header[k] = append(req.Header[k], s...)
+	}
+
+	return nil
+}
+
+// NewTransport returns an http.RoundTripper that modifies requests according to
+// the RequestModifiers passed in the arguments, before sending the requests to
+// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport).
+func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
+	return &transport{
+		Modifiers: modifiers,
+		Base:      base,
+	}
+}
+
+// transport is an http.RoundTripper that makes HTTP requests after
+// copying and modifying the request
+type transport struct {
+	Modifiers []RequestModifier
+	Base      http.RoundTripper
+
+	mu     sync.Mutex                      // guards modReq
+	modReq map[*http.Request]*http.Request // original -> modified
+}
+
+func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
+	req2 := CloneRequest(req)
+	for _, modifier := range t.Modifiers {
+		if err := modifier.ModifyRequest(req2); err != nil {
+			return nil, err
+		}
+	}
+
+	t.setModReq(req, req2)
+	res, err := t.base().RoundTrip(req2)
+	if err != nil {
+		t.setModReq(req, nil)
+		return nil, err
+	}
+	res.Body = &OnEOFReader{
+		Rc: res.Body,
+		Fn: func() { t.setModReq(req, nil) },
+	}
+	return res, nil
+}
+
+// CancelRequest cancels an in-flight request by closing its connection.
+func (t *transport) CancelRequest(req *http.Request) {
+	type canceler interface {
+		CancelRequest(*http.Request)
+	}
+	if cr, ok := t.base().(canceler); ok {
+		t.mu.Lock()
+		modReq := t.modReq[req]
+		delete(t.modReq, req)
+		t.mu.Unlock()
+		cr.CancelRequest(modReq)
+	}
+}
+
+func (t *transport) base() http.RoundTripper {
+	if t.Base != nil {
+		return t.Base
+	}
+	return http.DefaultTransport
+}
+
+func (t *transport) setModReq(orig, mod *http.Request) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.modReq == nil {
+		t.modReq = make(map[*http.Request]*http.Request)
+	}
+	if mod == nil {
+		delete(t.modReq, orig)
+	} else {
+		t.modReq[orig] = mod
+	}
+}
+
+// CloneRequest returns a clone of the provided *http.Request.
+// The clone is a shallow copy of the struct and its Header map.
+func CloneRequest(r *http.Request) *http.Request {
+	// shallow copy of the struct
+	r2 := new(http.Request)
+	*r2 = *r
+	// deep copy of the Header
+	r2.Header = make(http.Header, len(r.Header))
+	for k, s := range r.Header {
+		r2.Header[k] = append([]string(nil), s...)
+	}
+
+	return r2
+}
+
+// OnEOFReader ensures a callback function is called
+// on Close() and when the underlying Reader returns an io.EOF error
+type OnEOFReader struct {
+	Rc io.ReadCloser
+	Fn func()
+}
+
+func (r *OnEOFReader) Read(p []byte) (n int, err error) {
+	n, err = r.Rc.Read(p)
+	if err == io.EOF {
+		r.runFunc()
+	}
+	return
+}
+
+func (r *OnEOFReader) Close() error {
+	err := r.Rc.Close()
+	r.runFunc()
+	return err
+}
+
+func (r *OnEOFReader) runFunc() {
+	if fn := r.Fn; fn != nil {
+		fn()
+		r.Fn = nil
+	}
+}

+ 11 - 15
registry/auth.go

@@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) {
 		return auth.tokenCache, nil
 	}
 
-	client := auth.registryEndpoint.HTTPClient()
-
 	for _, challenge := range auth.registryEndpoint.AuthChallenges {
 		switch strings.ToLower(challenge.Scheme) {
 		case "basic":
@@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) {
 				params[k] = v
 			}
 			params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
-			token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client)
+			token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint)
 			if err != nil {
 				return "", err
 			}
@@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 		status        string
 		reqBody       []byte
 		err           error
-		client        = registryEndpoint.HTTPClient()
 		reqStatusCode = 0
 		serverAddress = authConfig.ServerAddress
 	)
@@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 
 	// using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status.
 	b := strings.NewReader(string(jsonBody))
-	req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
+	req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
 	if err != nil {
 		return "", fmt.Errorf("Server Error: %s", err)
 	}
@@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 		if string(reqBody) == "\"Username or email already exists\"" {
 			req, err := http.NewRequest("GET", serverAddress+"users/", nil)
 			req.SetBasicAuth(authConfig.Username, authConfig.Password)
-			resp, err := client.Do(req)
+			resp, err := registryEndpoint.client.Do(req)
 			if err != nil {
 				return "", err
 			}
@@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 		// protected, so people can use `docker login` as an auth check.
 		req, err := http.NewRequest("GET", serverAddress+"users/", nil)
 		req.SetBasicAuth(authConfig.Username, authConfig.Password)
-		resp, err := client.Do(req)
+		resp, err := registryEndpoint.client.Do(req)
 		if err != nil {
 			return "", err
 		}
@@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 	var (
 		err       error
 		allErrors []error
-		client    = registryEndpoint.HTTPClient()
 	)
 
 	for _, challenge := range registryEndpoint.AuthChallenges {
@@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 
 		switch strings.ToLower(challenge.Scheme) {
 		case "basic":
-			err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
+			err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
 		case "bearer":
-			err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
+			err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
 		default:
 			// Unsupported challenge types are explicitly skipped.
 			err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme)
@@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
 	return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors)
 }
 
-func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
+func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
 	req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
 	if err != nil {
 		return err
@@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
 
 	req.SetBasicAuth(authConfig.Username, authConfig.Password)
 
-	resp, err := client.Do(req)
+	resp, err := registryEndpoint.client.Do(req)
 	if err != nil {
 		return err
 	}
@@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
 	return nil
 }
 
-func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
-	token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client)
+func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
+	token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint)
 	if err != nil {
 		return err
 	}
@@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
 
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 
-	resp, err := client.Do(req)
+	resp, err := registryEndpoint.client.Do(req)
 	if err != nil {
 		return err
 	}

+ 10 - 15
registry/endpoint.go

@@ -11,6 +11,7 @@ import (
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/registry/api/v2"
+	"github.com/docker/docker/pkg/transport"
 )
 
 // for mocking in unit tests
@@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) {
 }
 
 // NewEndpoint parses the given address to return a registry endpoint.
-func NewEndpoint(index *IndexInfo) (*Endpoint, error) {
+func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) {
 	// *TODO: Allow per-registry configuration of endpoints.
-	endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure)
+	endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders)
 	if err != nil {
 		return nil, err
 	}
@@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error {
 	return nil
 }
 
-func newEndpoint(address string, secure bool) (*Endpoint, error) {
+func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) {
 	var (
 		endpoint       = new(Endpoint)
 		trimmedAddress string
@@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) {
 		return nil, err
 	}
 	endpoint.IsSecure = secure
+	tr := NewTransport(ConnectTimeout, endpoint.IsSecure)
+	endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...))
 	return endpoint, nil
 }
 
-func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) {
-	return NewEndpoint(repoInfo.Index)
+func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) {
+	return NewEndpoint(repoInfo.Index, metaHeaders)
 }
 
 // Endpoint stores basic information about a registry endpoint.
@@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) {
 		return RegistryInfo{Standalone: false}, err
 	}
 
-	resp, err := e.HTTPClient().Do(req)
+	resp, err := e.client.Do(req)
 	if err != nil {
 		return RegistryInfo{Standalone: false}, err
 	}
@@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) {
 		return RegistryInfo{}, err
 	}
 
-	resp, err := e.HTTPClient().Do(req)
+	resp, err := e.client.Do(req)
 	if err != nil {
 		return RegistryInfo{}, err
 	}
@@ -261,11 +264,3 @@ HeaderLoop:
 
 	return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode))
 }
-
-func (e *Endpoint) HTTPClient() *http.Client {
-	if e.client == nil {
-		tr := NewTransport(ConnectTimeout, e.IsSecure)
-		e.client = HTTPClient(tr)
-	}
-	return e.client
-}

+ 2 - 1
registry/endpoint_test.go

@@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) {
 		{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
 	}
 	for _, td := range testData {
-		e, err := newEndpoint(td.str, false)
+		e, err := newEndpoint(td.str, false, nil)
 		if err != nil {
 			t.Errorf("%q: %s", td.str, err)
 		}
@@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
 	testEndpoint := Endpoint{
 		URL:     testServerURL,
 		Version: APIVersionUnknown,
+		client:  HTTPClient(NewTransport(ConnectTimeout, false)),
 	}
 
 	if err = validateEndpoint(&testEndpoint); err != nil {

+ 44 - 62
registry/registry.go

@@ -19,6 +19,7 @@ import (
 	"github.com/docker/docker/autogen/dockerversion"
 	"github.com/docker/docker/pkg/parsers/kernel"
 	"github.com/docker/docker/pkg/timeoutconn"
+	"github.com/docker/docker/pkg/transport"
 	"github.com/docker/docker/pkg/useragent"
 )
 
@@ -36,17 +37,32 @@ const (
 	ConnectTimeout
 )
 
-type httpsTransport struct {
-	*http.Transport
+// dockerUserAgent is the User-Agent the Docker client uses to identify itself.
+// It is populated on init(), comprising version information of different components.
+var dockerUserAgent string
+
+func init() {
+	httpVersion := make([]useragent.VersionInfo, 0, 6)
+	httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
+	httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
+	httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
+	if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
+		httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
+	}
+	httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
+	httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
+
+	dockerUserAgent = useragent.AppendVersions("", httpVersion...)
 }
 
+type httpsRequestModifier struct{ tlsConfig *tls.Config }
+
 // DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip,
 // it's because it's so as to match the current behavior in master: we generate the
 // certpool on every-goddam-request. It's not great, but it allows people to just put
 // the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would
 // prefer an fsnotify implementation, but that was out of scope of my refactoring.
-// TODO: improve things
-func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error {
 	var (
 		roots *x509.CertPool
 		certs []tls.Certificate
@@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 		logrus.Debugf("hostDir: %s", hostDir)
 		fs, err := ioutil.ReadDir(hostDir)
 		if err != nil && !os.IsNotExist(err) {
-			return nil, err
+			return nil
 		}
 
 		for _, f := range fs {
@@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 				logrus.Debugf("crt: %s", hostDir+"/"+f.Name())
 				data, err := ioutil.ReadFile(path.Join(hostDir, f.Name()))
 				if err != nil {
-					return nil, err
+					return err
 				}
 				roots.AppendCertsFromPEM(data)
 			}
@@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 				keyName := certName[:len(certName)-5] + ".key"
 				logrus.Debugf("cert: %s", hostDir+"/"+f.Name())
 				if !hasFile(fs, keyName) {
-					return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
+					return fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
 				}
 				cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName))
 				if err != nil {
-					return nil, err
+					return err
 				}
 				certs = append(certs, cert)
 			}
@@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 				certName := keyName[:len(keyName)-4] + ".cert"
 				logrus.Debugf("key: %s", hostDir+"/"+f.Name())
 				if !hasFile(fs, certName) {
-					return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
+					return fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
 				}
 			}
 		}
-		if tr.Transport.TLSClientConfig == nil {
-			tr.Transport.TLSClientConfig = &tls.Config{
-				// Avoid fallback to SSL protocols < TLS1.0
-				MinVersion: tls.VersionTLS10,
-			}
-		}
-		tr.Transport.TLSClientConfig.RootCAs = roots
-		tr.Transport.TLSClientConfig.Certificates = certs
+		m.tlsConfig.RootCAs = roots
+		m.tlsConfig.Certificates = certs
 	}
-	return tr.Transport.RoundTrip(req)
+	return nil
 }
 
 func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
-	tlsConfig := tls.Config{
+	tlsConfig := &tls.Config{
 		// Avoid fallback to SSL protocols < TLS1.0
 		MinVersion:         tls.VersionTLS10,
 		InsecureSkipVerify: !secure,
 	}
 
-	transport := &http.Transport{
+	tr := &http.Transport{
 		DisableKeepAlives: true,
 		Proxy:             http.ProxyFromEnvironment,
-		TLSClientConfig:   &tlsConfig,
+		TLSClientConfig:   tlsConfig,
 	}
 
 	switch timeout {
 	case ConnectTimeout:
-		transport.Dial = func(proto string, addr string) (net.Conn, error) {
+		tr.Dial = func(proto string, addr string) (net.Conn, error) {
 			// Set the connect timeout to 30 seconds to allow for slower connection
 			// times...
 			d := net.Dialer{Timeout: 30 * time.Second, DualStack: true}
@@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
 			return conn, nil
 		}
 	case ReceiveTimeout:
-		transport.Dial = func(proto string, addr string) (net.Conn, error) {
+		tr.Dial = func(proto string, addr string) (net.Conn, error) {
 			d := net.Dialer{DualStack: true}
 
 			conn, err := d.Dial(proto, addr)
@@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
 	if secure {
 		// note: httpsTransport also handles http transport
 		// but for HTTPS, it sets up the certs
-		return &httpsTransport{transport}
+		return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig})
 	}
 
-	return transport
+	return tr
 }
 
-type DockerHeaders struct {
-	http.RoundTripper
-	Headers http.Header
-}
-
-// cloneRequest returns a clone of the provided *http.Request.
-// The clone is a shallow copy of the struct and its Header map
-func cloneRequest(r *http.Request) *http.Request {
-	// shallow copy of the struct
-	r2 := new(http.Request)
-	*r2 = *r
-	// deep copy of the Header
-	r2.Header = make(http.Header, len(r.Header))
-	for k, s := range r.Header {
-		r2.Header[k] = append([]string(nil), s...)
+// DockerHeaders returns request modifiers that ensure requests have
+// the User-Agent header set to dockerUserAgent and that metaHeaders
+// are added.
+func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier {
+	modifiers := []transport.RequestModifier{
+		transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}),
 	}
-	return r2
-}
-
-func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) {
-	req = cloneRequest(req)
-	httpVersion := make([]useragent.VersionInfo, 0, 4)
-	httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
-	httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
-	httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
-	if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
-		httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
-	}
-	httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
-	httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
-
-	userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...)
-
-	req.Header.Set("User-Agent", userAgent)
-
-	for k, v := range tr.Headers {
-		req.Header[k] = v
+	if metaHeaders != nil {
+		modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders))
 	}
-	return tr.RoundTripper.RoundTrip(req)
+	return modifiers
 }
 
 type debugTransport struct{ http.RoundTripper }

+ 8 - 7
registry/registry_test.go

@@ -8,6 +8,7 @@ import (
 	"testing"
 
 	"github.com/docker/docker/cliconfig"
+	"github.com/docker/docker/pkg/transport"
 )
 
 var (
@@ -21,12 +22,12 @@ const (
 
 func spawnTestRegistrySession(t *testing.T) *Session {
 	authConfig := &cliconfig.AuthConfig{}
-	endpoint, err := NewEndpoint(makeIndex("/v1/"))
+	endpoint, err := NewEndpoint(makeIndex("/v1/"), nil)
 	if err != nil {
 		t.Fatal(err)
 	}
 	var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)}
-	tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil}
+	tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...)
 	client := HTTPClient(tr)
 	r, err := NewSession(client, authConfig, endpoint)
 	if err != nil {
@@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session {
 
 func TestPingRegistryEndpoint(t *testing.T) {
 	testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) {
-		ep, err := NewEndpoint(index)
+		ep, err := NewEndpoint(index, nil)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) {
 func TestEndpoint(t *testing.T) {
 	// Simple wrapper to fail test if err != nil
 	expandEndpoint := func(index *IndexInfo) *Endpoint {
-		endpoint, err := NewEndpoint(index)
+		endpoint, err := NewEndpoint(index, nil)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) {
 
 	assertInsecureIndex := func(index *IndexInfo) {
 		index.Secure = true
-		_, err := NewEndpoint(index)
+		_, err := NewEndpoint(index, nil)
 		assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index")
 		assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry  error for insecure index")
 		index.Secure = false
@@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) {
 
 	assertSecureIndex := func(index *IndexInfo) {
 		index.Secure = true
-		_, err := NewEndpoint(index)
+		_, err := NewEndpoint(index, nil)
 		assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index")
 		assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index")
 		index.Secure = false
@@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) {
 	}
 	for _, address := range badEndpoints {
 		index.Name = address
-		_, err := NewEndpoint(index)
+		_, err := NewEndpoint(index, nil)
 		checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint")
 	}
 }

+ 8 - 4
registry/service.go

@@ -1,6 +1,10 @@
 package registry
 
-import "github.com/docker/docker/cliconfig"
+import (
+	"net/http"
+
+	"github.com/docker/docker/cliconfig"
+)
 
 type Service struct {
 	Config *ServiceConfig
@@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	endpoint, err := NewEndpoint(index)
+	endpoint, err := NewEndpoint(index, nil)
 	if err != nil {
 		return "", err
 	}
@@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers
 	}
 
 	// *TODO: Search multiple indexes.
-	endpoint, err := repoInfo.GetEndpoint()
+	endpoint, err := repoInfo.GetEndpoint(http.Header(headers))
 	if err != nil {
 		return nil, err
 	}
-	r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint)
+	r, err := NewSession(endpoint.client, authConfig, endpoint)
 	if err != nil {
 		return nil, err
 	}

+ 49 - 10
registry/session.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"crypto/sha256"
 	"errors"
+	"sync"
 	// this is required for some certificates
 	_ "crypto/sha512"
 	"encoding/hex"
@@ -22,6 +23,7 @@ import (
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/tarsum"
+	"github.com/docker/docker/pkg/transport"
 )
 
 type Session struct {
@@ -31,7 +33,18 @@ type Session struct {
 	authConfig *cliconfig.AuthConfig
 }
 
-// authTransport handles the auth layer when communicating with a v1 registry (private or official)
+type authTransport struct {
+	http.RoundTripper
+	*cliconfig.AuthConfig
+
+	alwaysSetBasicAuth bool
+	token              []string
+
+	mu     sync.Mutex                      // guards modReq
+	modReq map[*http.Request]*http.Request // original -> modified
+}
+
+// AuthTransport handles the auth layer when communicating with a v1 registry (private or official)
 //
 // For private v1 registries, set alwaysSetBasicAuth to true.
 //
@@ -44,16 +57,23 @@ type Session struct {
 // If the server sends a token without the client having requested it, it is ignored.
 //
 // This RoundTripper also has a CancelRequest method important for correct timeout handling.
-type authTransport struct {
-	http.RoundTripper
-	*cliconfig.AuthConfig
-
-	alwaysSetBasicAuth bool
-	token              []string
+func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper {
+	if base == nil {
+		base = http.DefaultTransport
+	}
+	return &authTransport{
+		RoundTripper:       base,
+		AuthConfig:         authConfig,
+		alwaysSetBasicAuth: alwaysSetBasicAuth,
+		modReq:             make(map[*http.Request]*http.Request),
+	}
 }
 
-func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	req = cloneRequest(req)
+func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) {
+	req := transport.CloneRequest(orig)
+	tr.mu.Lock()
+	tr.modReq[orig] = req
+	tr.mu.Unlock()
 
 	if tr.alwaysSetBasicAuth {
 		req.SetBasicAuth(tr.Username, tr.Password)
@@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 	resp, err := tr.RoundTripper.RoundTrip(req)
 	if err != nil {
+		delete(tr.modReq, orig)
 		return nil, err
 	}
 	if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 {
 		tr.token = resp.Header["X-Docker-Token"]
 	}
+	resp.Body = &transport.OnEOFReader{
+		Rc: resp.Body,
+		Fn: func() { delete(tr.modReq, orig) },
+	}
 	return resp, nil
 }
 
+// CancelRequest cancels an in-flight request by closing its connection.
+func (tr *authTransport) CancelRequest(req *http.Request) {
+	type canceler interface {
+		CancelRequest(*http.Request)
+	}
+	if cr, ok := tr.RoundTripper.(canceler); ok {
+		tr.mu.Lock()
+		modReq := tr.modReq[req]
+		delete(tr.modReq, req)
+		tr.mu.Unlock()
+		cr.CancelRequest(modReq)
+	}
+}
+
 // TODO(tiborvass): remove authConfig param once registry client v2 is vendored
 func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) {
 	r = &Session{
@@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint
 		}
 	}
 
-	client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth}
+	client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth)
 
 	jar, err := cookiejar.New(nil)
 	if err != nil {

+ 2 - 2
registry/session_v2.go

@@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder {
 func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) {
 	// TODO check if should use Mirror
 	if index.Official {
-		ep, err = newEndpoint(REGISTRYSERVER, true)
+		ep, err = newEndpoint(REGISTRYSERVER, true, nil)
 		if err != nil {
 			return
 		}
@@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error)
 	} else if r.indexEndpoint.String() == index.GetAuthConfigKey() {
 		ep = r.indexEndpoint
 	} else {
-		ep, err = NewEndpoint(index)
+		ep, err = NewEndpoint(index, nil)
 		if err != nil {
 			return
 		}

+ 2 - 2
registry/token.go

@@ -13,7 +13,7 @@ type tokenResponse struct {
 	Token string `json:"token"`
 }
 
-func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) {
+func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) {
 	realm, ok := params["realm"]
 	if !ok {
 		return "", errors.New("no realm specified for token auth challenge")
@@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo
 
 	req.URL.RawQuery = reqParams.Encode()
 
-	resp, err := client.Do(req)
+	resp, err := registryEndpoint.client.Do(req)
 	if err != nil {
 		return "", err
 	}