Bläddra i källkod

Use basic auth for private registries when over HTTPS.
RequestFactory is no longer a singleton (can be different for different instances of Registry)
Registry now has an indexEndpoint member
Registry methods that needed the indexEndpoint parameter no longer do so
Registry methods will only use token auth where applicable if basic auth is not enabled.

shin- 11 år sedan
förälder
incheckning
045989e3d8
3 ändrade filer med 70 tillägg och 44 borttagningar
  1. 47 15
      registry/registry.go
  2. 4 5
      registry/registry_test.go
  3. 19 24
      server.go

+ 47 - 15
registry/registry.go

@@ -160,7 +160,9 @@ func (r *Registry) GetRemoteHistory(imgID, registry string, token []string) ([]s
 	if err != nil {
 		return nil, err
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 		return nil, err
@@ -193,7 +195,9 @@ func (r *Registry) LookupRemoteImage(imgID, registry string, token []string) boo
 	if err != nil {
 		return false
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 		return false
@@ -209,7 +213,9 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([
 	if err != nil {
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
@@ -236,7 +242,9 @@ func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (
 	if err != nil {
 		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 		return nil, err
@@ -262,7 +270,9 @@ func (r *Registry) GetRemoteTags(registries []string, repository string, token [
 		if err != nil {
 			return nil, err
 		}
-		req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+		if req.Header.Get("Authorization") == "" { // Don't override
+			req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+		}
 		res, err := doWithCookies(r.client, req)
 		if err != nil {
 			return nil, err
@@ -290,7 +300,8 @@ func (r *Registry) GetRemoteTags(registries []string, repository string, token [
 	return nil, fmt.Errorf("Could not reach any registry endpoint")
 }
 
-func (r *Registry) GetRepositoryData(indexEp, remote string) (*RepositoryData, error) {
+func (r *Registry) GetRepositoryData(remote string) (*RepositoryData, error) {
+	indexEp := r.indexEndpoint
 	repositoryTarget := fmt.Sprintf("%srepositories/%s/images", indexEp, remote)
 
 	utils.Debugf("[registry] Calling GET %s", repositoryTarget)
@@ -364,7 +375,9 @@ func (r *Registry) PushImageChecksumRegistry(imgData *ImgData, registry string,
 	if err != nil {
 		return err
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	req.Header.Set("X-Docker-Checksum", imgData.Checksum)
 
 	res, err := doWithCookies(r.client, req)
@@ -401,7 +414,9 @@ func (r *Registry) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, regis
 		return err
 	}
 	req.Header.Add("Content-type", "application/json")
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
@@ -436,7 +451,9 @@ func (r *Registry) PushImageLayerRegistry(imgID string, layer io.Reader, registr
 	}
 	req.ContentLength = -1
 	req.TransferEncoding = []string{"chunked"}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 		return "", fmt.Errorf("Failed to upload layer: %s", err)
@@ -465,7 +482,9 @@ func (r *Registry) PushRegistryTag(remote, revision, tag, registry string, token
 		return err
 	}
 	req.Header.Add("Content-type", "application/json")
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	req.ContentLength = int64(len(revision))
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
@@ -478,8 +497,9 @@ func (r *Registry) PushRegistryTag(remote, revision, tag, registry string, token
 	return nil
 }
 
-func (r *Registry) PushImageJSONIndex(indexEp, remote string, imgList []*ImgData, validate bool, regs []string) (*RepositoryData, error) {
+func (r *Registry) PushImageJSONIndex(remote string, imgList []*ImgData, validate bool, regs []string) (*RepositoryData, error) {
 	cleanImgList := []*ImgData{}
+	indexEp := r.indexEndpoint
 
 	if validate {
 		for _, elem := range imgList {
@@ -583,6 +603,7 @@ func (r *Registry) PushImageJSONIndex(indexEp, remote string, imgList []*ImgData
 }
 
 func (r *Registry) SearchRepositories(term string) (*SearchResults, error) {
+	utils.Debugf("Index server: %s", r.indexEndpoint)
 	u := auth.IndexServerAddress() + "search?q=" + url.QueryEscape(term)
 	req, err := r.reqFactory.NewRequest("GET", u, nil)
 	if err != nil {
@@ -644,12 +665,13 @@ type ImgData struct {
 }
 
 type Registry struct {
-	client     *http.Client
-	authConfig *auth.AuthConfig
-	reqFactory *utils.HTTPRequestFactory
+	client		  *http.Client
+	authConfig	  *auth.AuthConfig
+	reqFactory	  *utils.HTTPRequestFactory
+	indexEndpoint string
 }
 
-func NewRegistry(root string, authConfig *auth.AuthConfig, factory *utils.HTTPRequestFactory) (r *Registry, err error) {
+func NewRegistry(authConfig *auth.AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
 	httpTransport := &http.Transport{
 		DisableKeepAlives: true,
 		Proxy:             http.ProxyFromEnvironment,
@@ -660,12 +682,22 @@ func NewRegistry(root string, authConfig *auth.AuthConfig, factory *utils.HTTPRe
 		client: &http.Client{
 			Transport: httpTransport,
 		},
+		indexEndpoint: indexEndpoint,
 	}
 	r.client.Jar, err = cookiejar.New(nil)
 	if err != nil {
 		return nil, err
 	}
 
+	// If we're working with a private registry over HTTPS, send Basic Auth headers
+	// alongside our requests.
+	if indexEndpoint != auth.IndexServerAddress() && strings.HasPrefix(indexEndpoint, "https://") {
+		utils.Debugf("Endpoint %s is eligible for private registry auth. Enabling decorator.", indexEndpoint)
+		dec := utils.NewHTTPAuthDecorator(authConfig.Username, authConfig.Password)
+		factory.AddDecorator(dec)
+	}
+
 	r.reqFactory = factory
 	return r, nil
 }
+

+ 4 - 5
registry/registry_test.go

@@ -15,7 +15,7 @@ var (
 
 func spawnTestRegistry(t *testing.T) *Registry {
 	authConfig := &auth.AuthConfig{}
-	r, err := NewRegistry("", authConfig, utils.NewHTTPRequestFactory())
+	r, err := NewRegistry(authConfig, utils.NewHTTPRequestFactory(), makeURL("/v1/"))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -99,7 +99,7 @@ func TestGetRemoteTags(t *testing.T) {
 
 func TestGetRepositoryData(t *testing.T) {
 	r := spawnTestRegistry(t)
-	data, err := r.GetRepositoryData(makeURL("/v1/"), "foo42/bar")
+	data, err := r.GetRepositoryData("foo42/bar")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -168,15 +168,14 @@ func TestPushImageJSONIndex(t *testing.T) {
 			Checksum: "sha256:bea7bf2e4bacd479344b737328db47b18880d09096e6674165533aa994f5e9f2",
 		},
 	}
-	ep := makeURL("/v1/")
-	repoData, err := r.PushImageJSONIndex(ep, "foo42/bar", imgData, false, nil)
+	repoData, err := r.PushImageJSONIndex("foo42/bar", imgData, false, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if repoData == nil {
 		t.Fatal("Expected RepositoryData object")
 	}
-	repoData, err = r.PushImageJSONIndex(ep, "foo42/bar", imgData, true, []string{ep})
+	repoData, err = r.PushImageJSONIndex("foo42/bar", imgData, true, []string{r.indexEndpoint})
 	if err != nil {
 		t.Fatal(err)
 	}

+ 19 - 24
server.go

@@ -425,7 +425,7 @@ func (srv *Server) recursiveLoad(address, tmpImageDir string) error {
 }
 
 func (srv *Server) ImagesSearch(term string) ([]registry.SearchResult, error) {
-	r, err := registry.NewRegistry(srv.runtime.config.Root, nil, srv.HTTPRequestFactory(nil))
+	r, err := registry.NewRegistry(nil, srv.HTTPRequestFactory(nil), auth.IndexServerAddress())
 	if err != nil {
 		return nil, err
 	}
@@ -816,10 +816,10 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 	return nil
 }
 
-func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName, remoteName, askedTag, indexEp string, sf *utils.StreamFormatter, parallel bool) error {
+func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName, remoteName, askedTag string, sf *utils.StreamFormatter, parallel bool) error {
 	out.Write(sf.FormatStatus("", "Pulling repository %s", localName))
 
-	repoData, err := r.GetRepositoryData(indexEp, remoteName)
+	repoData, err := r.GetRepositoryData(remoteName)
 	if err != nil {
 		return err
 	}
@@ -989,11 +989,6 @@ func (srv *Server) poolRemove(kind, key string) error {
 }
 
 func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *utils.StreamFormatter, authConfig *auth.AuthConfig, metaHeaders map[string][]string, parallel bool) error {
-	r, err := registry.NewRegistry(srv.runtime.config.Root, authConfig, srv.HTTPRequestFactory(metaHeaders))
-	if err != nil {
-		return err
-	}
-
 	out = utils.NewWriteFlusher(out)
 
 	c, err := srv.poolAdd("pull", localName+":"+tag)
@@ -1014,12 +1009,17 @@ func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *ut
 		return err
 	}
 
+	r, err := registry.NewRegistry(authConfig, srv.HTTPRequestFactory(metaHeaders), endpoint)
+	if err != nil {
+		return err
+	}
+
 	if endpoint == auth.IndexServerAddress() {
 		// If pull "index.docker.io/foo/bar", it's stored locally under "foo/bar"
 		localName = remoteName
 	}
 
-	if err = srv.pullRepository(r, out, localName, remoteName, tag, endpoint, sf, parallel); err != nil {
+	if err = srv.pullRepository(r, out, localName, remoteName, tag, sf, parallel); err != nil {
 		return err
 	}
 
@@ -1081,7 +1081,7 @@ func flatten(slc [][]*registry.ImgData) []*registry.ImgData {
 	return result
 }
 
-func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName, remoteName string, localRepo map[string]string, indexEp string, sf *utils.StreamFormatter) error {
+func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName, remoteName string, localRepo map[string]string, sf *utils.StreamFormatter) error {
 	out = utils.NewWriteFlusher(out)
 	imgList, err := srv.getImageList(localRepo)
 	if err != nil {
@@ -1091,7 +1091,7 @@ func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName
 	out.Write(sf.FormatStatus("", "Sending image list"))
 
 	var repoData *registry.RepositoryData
-	repoData, err = r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, false, nil)
+	repoData, err = r.PushImageJSONIndex(remoteName, flattenedImgList, false, nil)
 	if err != nil {
 		return err
 	}
@@ -1137,7 +1137,7 @@ func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName
 		}
 	}
 
-	if _, err := r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, true, repoData.Endpoints); err != nil {
+	if _, err := r.PushImageJSONIndex(remoteName, flattenedImgList, true, repoData.Endpoints); err != nil {
 		return err
 	}
 
@@ -1203,7 +1203,7 @@ func (srv *Server) ImagePush(localName string, out io.Writer, sf *utils.StreamFo
 
 	out = utils.NewWriteFlusher(out)
 	img, err := srv.runtime.graph.Get(localName)
-	r, err2 := registry.NewRegistry(srv.runtime.config.Root, authConfig, srv.HTTPRequestFactory(metaHeaders))
+	r, err2 := registry.NewRegistry(authConfig, srv.HTTPRequestFactory(metaHeaders), endpoint)
 	if err2 != nil {
 		return err2
 	}
@@ -1213,7 +1213,7 @@ func (srv *Server) ImagePush(localName string, out io.Writer, sf *utils.StreamFo
 		out.Write(sf.FormatStatus("", "The push refers to a repository [%s] (len: %d)", localName, reposLen))
 		// If it fails, try to get the repository
 		if localRepo, exists := srv.runtime.repositories.Repositories[localName]; exists {
-			if err := srv.pushRepository(r, out, localName, remoteName, localRepo, endpoint, sf); err != nil {
+			if err := srv.pushRepository(r, out, localName, remoteName, localRepo, sf); err != nil {
 				return err
 			}
 			return nil
@@ -1852,7 +1852,6 @@ func NewServer(eng *engine.Engine, config *DaemonConfig) (*Server, error) {
 		pushingPool: make(map[string]chan struct{}),
 		events:      make([]utils.JSONMessage, 0, 64), //only keeps the 64 last events
 		listeners:   make(map[string]chan utils.JSONMessage),
-		reqFactory:  nil,
 	}
 	runtime.srv = srv
 	return srv, nil
@@ -1861,15 +1860,12 @@ func NewServer(eng *engine.Engine, config *DaemonConfig) (*Server, error) {
 func (srv *Server) HTTPRequestFactory(metaHeaders map[string][]string) *utils.HTTPRequestFactory {
 	srv.Lock()
 	defer srv.Unlock()
-	if srv.reqFactory == nil {
-		ud := utils.NewHTTPUserAgentDecorator(srv.versionInfos()...)
-		md := &utils.HTTPMetaHeadersDecorator{
-			Headers: metaHeaders,
-		}
-		factory := utils.NewHTTPRequestFactory(ud, md)
-		srv.reqFactory = factory
+	ud := utils.NewHTTPUserAgentDecorator(srv.versionInfos()...)
+	md := &utils.HTTPMetaHeadersDecorator{
+		Headers: metaHeaders,
 	}
-	return srv.reqFactory
+	factory := utils.NewHTTPRequestFactory(ud, md)
+	return factory
 }
 
 func (srv *Server) LogEvent(action, id, from string) *utils.JSONMessage {
@@ -1904,6 +1900,5 @@ type Server struct {
 	pushingPool map[string]chan struct{}
 	events      []utils.JSONMessage
 	listeners   map[string]chan utils.JSONMessage
-	reqFactory  *utils.HTTPRequestFactory
 	Eng         *engine.Engine
 }