فهرست منبع

Use basic auth for private registries when over HTTPS.
RequestFactory is no longer a singleton (can be different for different instances of Registry)
Registry now has an indexEndpoint member
Registry methods that needed the indexEndpoint parameter no longer do so
Registry methods will only use token auth where applicable if basic auth is not enabled.

shin- 11 سال پیش
والد
کامیت
045989e3d8
3فایلهای تغییر یافته به همراه70 افزوده شده و 44 حذف شده
  1. 47 15
      registry/registry.go
  2. 4 5
      registry/registry_test.go
  3. 19 24
      server.go

+ 47 - 15
registry/registry.go

@@ -160,7 +160,9 @@ func (r *Registry) GetRemoteHistory(imgID, registry string, token []string) ([]s
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -193,7 +195,9 @@ func (r *Registry) LookupRemoteImage(imgID, registry string, token []string) boo
 	if err != nil {
 	if err != nil {
 		return false
 		return false
 	}
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
 		return false
 		return false
@@ -209,7 +213,9 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([
 	if err != nil {
 	if err != nil {
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
 	}
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
 		return nil, -1, fmt.Errorf("Failed to download json: %s", err)
@@ -236,7 +242,9 @@ func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
 		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
 	}
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -262,7 +270,9 @@ func (r *Registry) GetRemoteTags(registries []string, repository string, token [
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		req.Header.Set("Authorization", "Token "+strings.Join(token, ", "))
+		if req.Header.Get("Authorization") == "" { // Don't override
+			req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+		}
 		res, err := doWithCookies(r.client, req)
 		res, err := doWithCookies(r.client, req)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -290,7 +300,8 @@ func (r *Registry) GetRemoteTags(registries []string, repository string, token [
 	return nil, fmt.Errorf("Could not reach any registry endpoint")
 	return nil, fmt.Errorf("Could not reach any registry endpoint")
 }
 }
 
 
-func (r *Registry) GetRepositoryData(indexEp, remote string) (*RepositoryData, error) {
+func (r *Registry) GetRepositoryData(remote string) (*RepositoryData, error) {
+	indexEp := r.indexEndpoint
 	repositoryTarget := fmt.Sprintf("%srepositories/%s/images", indexEp, remote)
 	repositoryTarget := fmt.Sprintf("%srepositories/%s/images", indexEp, remote)
 
 
 	utils.Debugf("[registry] Calling GET %s", repositoryTarget)
 	utils.Debugf("[registry] Calling GET %s", repositoryTarget)
@@ -364,7 +375,9 @@ func (r *Registry) PushImageChecksumRegistry(imgData *ImgData, registry string,
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	req.Header.Set("X-Docker-Checksum", imgData.Checksum)
 	req.Header.Set("X-Docker-Checksum", imgData.Checksum)
 
 
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
@@ -401,7 +414,9 @@ func (r *Registry) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, regis
 		return err
 		return err
 	}
 	}
 	req.Header.Add("Content-type", "application/json")
 	req.Header.Add("Content-type", "application/json")
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 
 
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
@@ -436,7 +451,9 @@ func (r *Registry) PushImageLayerRegistry(imgID string, layer io.Reader, registr
 	}
 	}
 	req.ContentLength = -1
 	req.ContentLength = -1
 	req.TransferEncoding = []string{"chunked"}
 	req.TransferEncoding = []string{"chunked"}
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
 		return "", fmt.Errorf("Failed to upload layer: %s", err)
 		return "", fmt.Errorf("Failed to upload layer: %s", err)
@@ -465,7 +482,9 @@ func (r *Registry) PushRegistryTag(remote, revision, tag, registry string, token
 		return err
 		return err
 	}
 	}
 	req.Header.Add("Content-type", "application/json")
 	req.Header.Add("Content-type", "application/json")
-	req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	if req.Header.Get("Authorization") == "" { // Don't override
+		req.Header.Set("Authorization", "Token "+strings.Join(token, ","))
+	}
 	req.ContentLength = int64(len(revision))
 	req.ContentLength = int64(len(revision))
 	res, err := doWithCookies(r.client, req)
 	res, err := doWithCookies(r.client, req)
 	if err != nil {
 	if err != nil {
@@ -478,8 +497,9 @@ func (r *Registry) PushRegistryTag(remote, revision, tag, registry string, token
 	return nil
 	return nil
 }
 }
 
 
-func (r *Registry) PushImageJSONIndex(indexEp, remote string, imgList []*ImgData, validate bool, regs []string) (*RepositoryData, error) {
+func (r *Registry) PushImageJSONIndex(remote string, imgList []*ImgData, validate bool, regs []string) (*RepositoryData, error) {
 	cleanImgList := []*ImgData{}
 	cleanImgList := []*ImgData{}
+	indexEp := r.indexEndpoint
 
 
 	if validate {
 	if validate {
 		for _, elem := range imgList {
 		for _, elem := range imgList {
@@ -583,6 +603,7 @@ func (r *Registry) PushImageJSONIndex(indexEp, remote string, imgList []*ImgData
 }
 }
 
 
 func (r *Registry) SearchRepositories(term string) (*SearchResults, error) {
 func (r *Registry) SearchRepositories(term string) (*SearchResults, error) {
+	utils.Debugf("Index server: %s", r.indexEndpoint)
 	u := auth.IndexServerAddress() + "search?q=" + url.QueryEscape(term)
 	u := auth.IndexServerAddress() + "search?q=" + url.QueryEscape(term)
 	req, err := r.reqFactory.NewRequest("GET", u, nil)
 	req, err := r.reqFactory.NewRequest("GET", u, nil)
 	if err != nil {
 	if err != nil {
@@ -644,12 +665,13 @@ type ImgData struct {
 }
 }
 
 
 type Registry struct {
 type Registry struct {
-	client     *http.Client
-	authConfig *auth.AuthConfig
-	reqFactory *utils.HTTPRequestFactory
+	client		  *http.Client
+	authConfig	  *auth.AuthConfig
+	reqFactory	  *utils.HTTPRequestFactory
+	indexEndpoint string
 }
 }
 
 
-func NewRegistry(root string, authConfig *auth.AuthConfig, factory *utils.HTTPRequestFactory) (r *Registry, err error) {
+func NewRegistry(authConfig *auth.AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
 	httpTransport := &http.Transport{
 	httpTransport := &http.Transport{
 		DisableKeepAlives: true,
 		DisableKeepAlives: true,
 		Proxy:             http.ProxyFromEnvironment,
 		Proxy:             http.ProxyFromEnvironment,
@@ -660,12 +682,22 @@ func NewRegistry(root string, authConfig *auth.AuthConfig, factory *utils.HTTPRe
 		client: &http.Client{
 		client: &http.Client{
 			Transport: httpTransport,
 			Transport: httpTransport,
 		},
 		},
+		indexEndpoint: indexEndpoint,
 	}
 	}
 	r.client.Jar, err = cookiejar.New(nil)
 	r.client.Jar, err = cookiejar.New(nil)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	// If we're working with a private registry over HTTPS, send Basic Auth headers
+	// alongside our requests.
+	if indexEndpoint != auth.IndexServerAddress() && strings.HasPrefix(indexEndpoint, "https://") {
+		utils.Debugf("Endpoint %s is eligible for private registry auth. Enabling decorator.", indexEndpoint)
+		dec := utils.NewHTTPAuthDecorator(authConfig.Username, authConfig.Password)
+		factory.AddDecorator(dec)
+	}
+
 	r.reqFactory = factory
 	r.reqFactory = factory
 	return r, nil
 	return r, nil
 }
 }
+

+ 4 - 5
registry/registry_test.go

@@ -15,7 +15,7 @@ var (
 
 
 func spawnTestRegistry(t *testing.T) *Registry {
 func spawnTestRegistry(t *testing.T) *Registry {
 	authConfig := &auth.AuthConfig{}
 	authConfig := &auth.AuthConfig{}
-	r, err := NewRegistry("", authConfig, utils.NewHTTPRequestFactory())
+	r, err := NewRegistry(authConfig, utils.NewHTTPRequestFactory(), makeURL("/v1/"))
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -99,7 +99,7 @@ func TestGetRemoteTags(t *testing.T) {
 
 
 func TestGetRepositoryData(t *testing.T) {
 func TestGetRepositoryData(t *testing.T) {
 	r := spawnTestRegistry(t)
 	r := spawnTestRegistry(t)
-	data, err := r.GetRepositoryData(makeURL("/v1/"), "foo42/bar")
+	data, err := r.GetRepositoryData("foo42/bar")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -168,15 +168,14 @@ func TestPushImageJSONIndex(t *testing.T) {
 			Checksum: "sha256:bea7bf2e4bacd479344b737328db47b18880d09096e6674165533aa994f5e9f2",
 			Checksum: "sha256:bea7bf2e4bacd479344b737328db47b18880d09096e6674165533aa994f5e9f2",
 		},
 		},
 	}
 	}
-	ep := makeURL("/v1/")
-	repoData, err := r.PushImageJSONIndex(ep, "foo42/bar", imgData, false, nil)
+	repoData, err := r.PushImageJSONIndex("foo42/bar", imgData, false, nil)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 	if repoData == nil {
 	if repoData == nil {
 		t.Fatal("Expected RepositoryData object")
 		t.Fatal("Expected RepositoryData object")
 	}
 	}
-	repoData, err = r.PushImageJSONIndex(ep, "foo42/bar", imgData, true, []string{ep})
+	repoData, err = r.PushImageJSONIndex("foo42/bar", imgData, true, []string{r.indexEndpoint})
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 19 - 24
server.go

@@ -425,7 +425,7 @@ func (srv *Server) recursiveLoad(address, tmpImageDir string) error {
 }
 }
 
 
 func (srv *Server) ImagesSearch(term string) ([]registry.SearchResult, error) {
 func (srv *Server) ImagesSearch(term string) ([]registry.SearchResult, error) {
-	r, err := registry.NewRegistry(srv.runtime.config.Root, nil, srv.HTTPRequestFactory(nil))
+	r, err := registry.NewRegistry(nil, srv.HTTPRequestFactory(nil), auth.IndexServerAddress())
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -816,10 +816,10 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 	return nil
 	return nil
 }
 }
 
 
-func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName, remoteName, askedTag, indexEp string, sf *utils.StreamFormatter, parallel bool) error {
+func (srv *Server) pullRepository(r *registry.Registry, out io.Writer, localName, remoteName, askedTag string, sf *utils.StreamFormatter, parallel bool) error {
 	out.Write(sf.FormatStatus("", "Pulling repository %s", localName))
 	out.Write(sf.FormatStatus("", "Pulling repository %s", localName))
 
 
-	repoData, err := r.GetRepositoryData(indexEp, remoteName)
+	repoData, err := r.GetRepositoryData(remoteName)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -989,11 +989,6 @@ func (srv *Server) poolRemove(kind, key string) error {
 }
 }
 
 
 func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *utils.StreamFormatter, authConfig *auth.AuthConfig, metaHeaders map[string][]string, parallel bool) error {
 func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *utils.StreamFormatter, authConfig *auth.AuthConfig, metaHeaders map[string][]string, parallel bool) error {
-	r, err := registry.NewRegistry(srv.runtime.config.Root, authConfig, srv.HTTPRequestFactory(metaHeaders))
-	if err != nil {
-		return err
-	}
-
 	out = utils.NewWriteFlusher(out)
 	out = utils.NewWriteFlusher(out)
 
 
 	c, err := srv.poolAdd("pull", localName+":"+tag)
 	c, err := srv.poolAdd("pull", localName+":"+tag)
@@ -1014,12 +1009,17 @@ func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *ut
 		return err
 		return err
 	}
 	}
 
 
+	r, err := registry.NewRegistry(authConfig, srv.HTTPRequestFactory(metaHeaders), endpoint)
+	if err != nil {
+		return err
+	}
+
 	if endpoint == auth.IndexServerAddress() {
 	if endpoint == auth.IndexServerAddress() {
 		// If pull "index.docker.io/foo/bar", it's stored locally under "foo/bar"
 		// If pull "index.docker.io/foo/bar", it's stored locally under "foo/bar"
 		localName = remoteName
 		localName = remoteName
 	}
 	}
 
 
-	if err = srv.pullRepository(r, out, localName, remoteName, tag, endpoint, sf, parallel); err != nil {
+	if err = srv.pullRepository(r, out, localName, remoteName, tag, sf, parallel); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -1081,7 +1081,7 @@ func flatten(slc [][]*registry.ImgData) []*registry.ImgData {
 	return result
 	return result
 }
 }
 
 
-func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName, remoteName string, localRepo map[string]string, indexEp string, sf *utils.StreamFormatter) error {
+func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName, remoteName string, localRepo map[string]string, sf *utils.StreamFormatter) error {
 	out = utils.NewWriteFlusher(out)
 	out = utils.NewWriteFlusher(out)
 	imgList, err := srv.getImageList(localRepo)
 	imgList, err := srv.getImageList(localRepo)
 	if err != nil {
 	if err != nil {
@@ -1091,7 +1091,7 @@ func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName
 	out.Write(sf.FormatStatus("", "Sending image list"))
 	out.Write(sf.FormatStatus("", "Sending image list"))
 
 
 	var repoData *registry.RepositoryData
 	var repoData *registry.RepositoryData
-	repoData, err = r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, false, nil)
+	repoData, err = r.PushImageJSONIndex(remoteName, flattenedImgList, false, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -1137,7 +1137,7 @@ func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName
 		}
 		}
 	}
 	}
 
 
-	if _, err := r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, true, repoData.Endpoints); err != nil {
+	if _, err := r.PushImageJSONIndex(remoteName, flattenedImgList, true, repoData.Endpoints); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -1203,7 +1203,7 @@ func (srv *Server) ImagePush(localName string, out io.Writer, sf *utils.StreamFo
 
 
 	out = utils.NewWriteFlusher(out)
 	out = utils.NewWriteFlusher(out)
 	img, err := srv.runtime.graph.Get(localName)
 	img, err := srv.runtime.graph.Get(localName)
-	r, err2 := registry.NewRegistry(srv.runtime.config.Root, authConfig, srv.HTTPRequestFactory(metaHeaders))
+	r, err2 := registry.NewRegistry(authConfig, srv.HTTPRequestFactory(metaHeaders), endpoint)
 	if err2 != nil {
 	if err2 != nil {
 		return err2
 		return err2
 	}
 	}
@@ -1213,7 +1213,7 @@ func (srv *Server) ImagePush(localName string, out io.Writer, sf *utils.StreamFo
 		out.Write(sf.FormatStatus("", "The push refers to a repository [%s] (len: %d)", localName, reposLen))
 		out.Write(sf.FormatStatus("", "The push refers to a repository [%s] (len: %d)", localName, reposLen))
 		// If it fails, try to get the repository
 		// If it fails, try to get the repository
 		if localRepo, exists := srv.runtime.repositories.Repositories[localName]; exists {
 		if localRepo, exists := srv.runtime.repositories.Repositories[localName]; exists {
-			if err := srv.pushRepository(r, out, localName, remoteName, localRepo, endpoint, sf); err != nil {
+			if err := srv.pushRepository(r, out, localName, remoteName, localRepo, sf); err != nil {
 				return err
 				return err
 			}
 			}
 			return nil
 			return nil
@@ -1852,7 +1852,6 @@ func NewServer(eng *engine.Engine, config *DaemonConfig) (*Server, error) {
 		pushingPool: make(map[string]chan struct{}),
 		pushingPool: make(map[string]chan struct{}),
 		events:      make([]utils.JSONMessage, 0, 64), //only keeps the 64 last events
 		events:      make([]utils.JSONMessage, 0, 64), //only keeps the 64 last events
 		listeners:   make(map[string]chan utils.JSONMessage),
 		listeners:   make(map[string]chan utils.JSONMessage),
-		reqFactory:  nil,
 	}
 	}
 	runtime.srv = srv
 	runtime.srv = srv
 	return srv, nil
 	return srv, nil
@@ -1861,15 +1860,12 @@ func NewServer(eng *engine.Engine, config *DaemonConfig) (*Server, error) {
 func (srv *Server) HTTPRequestFactory(metaHeaders map[string][]string) *utils.HTTPRequestFactory {
 func (srv *Server) HTTPRequestFactory(metaHeaders map[string][]string) *utils.HTTPRequestFactory {
 	srv.Lock()
 	srv.Lock()
 	defer srv.Unlock()
 	defer srv.Unlock()
-	if srv.reqFactory == nil {
-		ud := utils.NewHTTPUserAgentDecorator(srv.versionInfos()...)
-		md := &utils.HTTPMetaHeadersDecorator{
-			Headers: metaHeaders,
-		}
-		factory := utils.NewHTTPRequestFactory(ud, md)
-		srv.reqFactory = factory
+	ud := utils.NewHTTPUserAgentDecorator(srv.versionInfos()...)
+	md := &utils.HTTPMetaHeadersDecorator{
+		Headers: metaHeaders,
 	}
 	}
-	return srv.reqFactory
+	factory := utils.NewHTTPRequestFactory(ud, md)
+	return factory
 }
 }
 
 
 func (srv *Server) LogEvent(action, id, from string) *utils.JSONMessage {
 func (srv *Server) LogEvent(action, id, from string) *utils.JSONMessage {
@@ -1904,6 +1900,5 @@ type Server struct {
 	pushingPool map[string]chan struct{}
 	pushingPool map[string]chan struct{}
 	events      []utils.JSONMessage
 	events      []utils.JSONMessage
 	listeners   map[string]chan utils.JSONMessage
 	listeners   map[string]chan utils.JSONMessage
-	reqFactory  *utils.HTTPRequestFactory
 	Eng         *engine.Engine
 	Eng         *engine.Engine
 }
 }