Browse Source

Merge pull request #20832 from aaronlehmann/login-endpoint-refactor

Update login to use token handling code from distribution
Sebastiaan van Stijn 9 years ago
parent
commit
17156ba98f

+ 1 - 1
distribution/pull.go

@@ -88,7 +88,7 @@ func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullCo
 		return err
 		return err
 	}
 	}
 
 
-	endpoints, err := imagePullConfig.RegistryService.LookupPullEndpoints(repoInfo)
+	endpoints, err := imagePullConfig.RegistryService.LookupPullEndpoints(repoInfo.Hostname())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 1 - 1
distribution/push.go

@@ -100,7 +100,7 @@ func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushCo
 		return err
 		return err
 	}
 	}
 
 
-	endpoints, err := imagePushConfig.RegistryService.LookupPushEndpoints(repoInfo)
+	endpoints, err := imagePushConfig.RegistryService.LookupPushEndpoints(repoInfo.Hostname())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 7 - 38
distribution/registry.go

@@ -5,7 +5,6 @@ import (
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
-	"strings"
 	"time"
 	"time"
 
 
 	"github.com/docker/distribution"
 	"github.com/docker/distribution"
@@ -53,48 +52,18 @@ func NewV2Repository(ctx context.Context, repoInfo *registry.RepositoryInfo, end
 
 
 	modifiers := registry.DockerHeaders(dockerversion.DockerUserAgent(), metaHeaders)
 	modifiers := registry.DockerHeaders(dockerversion.DockerUserAgent(), metaHeaders)
 	authTransport := transport.NewTransport(base, modifiers...)
 	authTransport := transport.NewTransport(base, modifiers...)
-	pingClient := &http.Client{
-		Transport: authTransport,
-		Timeout:   15 * time.Second,
-	}
-	endpointStr := strings.TrimRight(endpoint.URL.String(), "/") + "/v2/"
-	req, err := http.NewRequest("GET", endpointStr, nil)
-	if err != nil {
-		return nil, false, fallbackError{err: err}
-	}
-	resp, err := pingClient.Do(req)
-	if err != nil {
-		return nil, false, fallbackError{err: err}
-	}
-	defer resp.Body.Close()
-
-	// We got a HTTP request through, so we're using the right TLS settings.
-	// From this point forward, set transportOK to true in any fallbackError
-	// we return.
 
 
-	v2Version := auth.APIVersion{
-		Type:    "registry",
-		Version: "2.0",
-	}
-
-	versions := auth.APIVersions(resp, registry.DefaultRegistryVersionHeader)
-	for _, pingVersion := range versions {
-		if pingVersion == v2Version {
-			// The version header indicates we're definitely
-			// talking to a v2 registry. So don't allow future
-			// fallbacks to the v1 protocol.
-
-			foundVersion = true
-			break
+	challengeManager, foundVersion, err := registry.PingV2Registry(endpoint, authTransport)
+	if err != nil {
+		transportOK := false
+		if responseErr, ok := err.(registry.PingResponseError); ok {
+			transportOK = true
+			err = responseErr.Err
 		}
 		}
-	}
-
-	challengeManager := auth.NewSimpleChallengeManager()
-	if err := challengeManager.AddResponse(resp); err != nil {
 		return nil, foundVersion, fallbackError{
 		return nil, foundVersion, fallbackError{
 			err:         err,
 			err:         err,
 			confirmedV2: foundVersion,
 			confirmedV2: foundVersion,
-			transportOK: true,
+			transportOK: transportOK,
 		}
 		}
 	}
 	}
 
 

+ 3 - 4
integration-cli/docker_cli_v2_only_test.go

@@ -106,20 +106,19 @@ func (s *DockerRegistrySuite) TestV1(c *check.C) {
 	defer cleanup()
 	defer cleanup()
 
 
 	s.d.Cmd("build", "--file", dockerfileName, ".")
 	s.d.Cmd("build", "--file", dockerfileName, ".")
-	c.Assert(v1Repo, check.Not(check.Equals), 0, check.Commentf("Expected v1 repository access after build"))
+	c.Assert(v1Repo, check.Equals, 1, check.Commentf("Expected v1 repository access after build"))
 
 
 	repoName := fmt.Sprintf("%s/busybox", reg.hostport)
 	repoName := fmt.Sprintf("%s/busybox", reg.hostport)
 	s.d.Cmd("run", repoName)
 	s.d.Cmd("run", repoName)
-	c.Assert(v1Repo, check.Not(check.Equals), 1, check.Commentf("Expected v1 repository access after run"))
+	c.Assert(v1Repo, check.Equals, 2, check.Commentf("Expected v1 repository access after run"))
 
 
 	s.d.Cmd("login", "-u", "richard", "-p", "testtest", reg.hostport)
 	s.d.Cmd("login", "-u", "richard", "-p", "testtest", reg.hostport)
-	c.Assert(v1Logins, check.Not(check.Equals), 0, check.Commentf("Expected v1 login attempt"))
+	c.Assert(v1Logins, check.Equals, 1, check.Commentf("Expected v1 login attempt"))
 
 
 	s.d.Cmd("tag", "busybox", repoName)
 	s.d.Cmd("tag", "busybox", repoName)
 	s.d.Cmd("push", repoName)
 	s.d.Cmd("push", repoName)
 
 
 	c.Assert(v1Repo, check.Equals, 2)
 	c.Assert(v1Repo, check.Equals, 2)
-	c.Assert(v1Pings, check.Equals, 1)
 
 
 	s.d.Cmd("pull", repoName)
 	s.d.Cmd("pull", repoName)
 	c.Assert(v1Repo, check.Equals, 3, check.Commentf("Expected v1 repository access after pull"))
 	c.Assert(v1Repo, check.Equals, 3, check.Commentf("Expected v1 repository access after pull"))

+ 128 - 80
registry/auth.go

@@ -4,28 +4,25 @@ import (
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"net/http"
 	"net/http"
+	"net/url"
 	"strings"
 	"strings"
+	"time"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
+	"github.com/docker/distribution/registry/client/auth"
+	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/engine-api/types"
 	"github.com/docker/engine-api/types"
 	registrytypes "github.com/docker/engine-api/types/registry"
 	registrytypes "github.com/docker/engine-api/types/registry"
 )
 )
 
 
-// Login tries to register/login to the registry server.
-func Login(authConfig *types.AuthConfig, registryEndpoint *Endpoint) (string, error) {
-	// Separates the v2 registry login logic from the v1 logic.
-	if registryEndpoint.Version == APIVersion2 {
-		return loginV2(authConfig, registryEndpoint, "" /* scope */)
+// loginV1 tries to register/login to the v1 registry server.
+func loginV1(authConfig *types.AuthConfig, apiEndpoint APIEndpoint, userAgent string) (string, error) {
+	registryEndpoint, err := apiEndpoint.ToV1Endpoint(userAgent, nil)
+	if err != nil {
+		return "", err
 	}
 	}
-	return loginV1(authConfig, registryEndpoint)
-}
 
 
-// loginV1 tries to register/login to the v1 registry server.
-func loginV1(authConfig *types.AuthConfig, registryEndpoint *Endpoint) (string, error) {
-	var (
-		err           error
-		serverAddress = authConfig.ServerAddress
-	)
+	serverAddress := registryEndpoint.String()
 
 
 	logrus.Debugf("attempting v1 login to registry endpoint %s", registryEndpoint)
 	logrus.Debugf("attempting v1 login to registry endpoint %s", registryEndpoint)
 
 
@@ -36,10 +33,16 @@ func loginV1(authConfig *types.AuthConfig, registryEndpoint *Endpoint) (string,
 	loginAgainstOfficialIndex := serverAddress == IndexServer
 	loginAgainstOfficialIndex := serverAddress == IndexServer
 
 
 	req, err := http.NewRequest("GET", serverAddress+"users/", nil)
 	req, err := http.NewRequest("GET", serverAddress+"users/", nil)
+	if err != nil {
+		return "", err
+	}
 	req.SetBasicAuth(authConfig.Username, authConfig.Password)
 	req.SetBasicAuth(authConfig.Username, authConfig.Password)
 	resp, err := registryEndpoint.client.Do(req)
 	resp, err := registryEndpoint.client.Do(req)
 	if err != nil {
 	if err != nil {
-		return "", err
+		// fallback when request could not be completed
+		return "", fallbackError{
+			err: err,
+		}
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
 	body, err := ioutil.ReadAll(resp.Body)
 	body, err := ioutil.ReadAll(resp.Body)
@@ -68,97 +71,82 @@ func loginV1(authConfig *types.AuthConfig, registryEndpoint *Endpoint) (string,
 	}
 	}
 }
 }
 
 
-// loginV2 tries to login to the v2 registry server. The given registry endpoint has been
-// pinged or setup with a list of authorization challenges. Each of these challenges are
-// tried until one of them succeeds. Currently supported challenge schemes are:
-// 		HTTP Basic Authorization
-// 		Token Authorization with a separate token issuing server
-// NOTE: the v2 logic does not attempt to create a user account if one doesn't exist. For
-// now, users should create their account through other means like directly from a web page
-// served by the v2 registry service provider. Whether this will be supported in the future
-// is to be determined.
-func loginV2(authConfig *types.AuthConfig, registryEndpoint *Endpoint, scope string) (string, error) {
-	logrus.Debugf("attempting v2 login to registry endpoint %s", registryEndpoint)
-	var (
-		err       error
-		allErrors []error
-	)
-
-	for _, challenge := range registryEndpoint.AuthChallenges {
-		params := make(map[string]string, len(challenge.Parameters)+1)
-		for k, v := range challenge.Parameters {
-			params[k] = v
-		}
-		params["scope"] = scope
-		logrus.Debugf("trying %q auth challenge with params %v", challenge.Scheme, params)
-
-		switch strings.ToLower(challenge.Scheme) {
-		case "basic":
-			err = tryV2BasicAuthLogin(authConfig, params, registryEndpoint)
-		case "bearer":
-			err = tryV2TokenAuthLogin(authConfig, params, registryEndpoint)
-		default:
-			// Unsupported challenge types are explicitly skipped.
-			err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme)
-		}
-
-		if err == nil {
-			return "Login Succeeded", nil
-		}
+type loginCredentialStore struct {
+	authConfig *types.AuthConfig
+}
 
 
-		logrus.Debugf("error trying auth challenge %q: %s", challenge.Scheme, err)
+func (lcs loginCredentialStore) Basic(*url.URL) (string, string) {
+	return lcs.authConfig.Username, lcs.authConfig.Password
+}
 
 
-		allErrors = append(allErrors, err)
-	}
+type fallbackError struct {
+	err error
+}
 
 
-	return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors)
+func (err fallbackError) Error() string {
+	return err.err.Error()
 }
 }
 
 
-func tryV2BasicAuthLogin(authConfig *types.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
-	req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
-	if err != nil {
-		return err
-	}
+// loginV2 tries to login to the v2 registry server. The given registry
+// endpoint will be pinged to get authorization challenges. These challenges
+// will be used to authenticate against the registry to validate credentials.
+func loginV2(authConfig *types.AuthConfig, endpoint APIEndpoint, userAgent string) (string, error) {
+	logrus.Debugf("attempting v2 login to registry endpoint %s", endpoint)
 
 
-	req.SetBasicAuth(authConfig.Username, authConfig.Password)
+	modifiers := DockerHeaders(userAgent, nil)
+	authTransport := transport.NewTransport(NewTransport(endpoint.TLSConfig), modifiers...)
 
 
-	resp, err := registryEndpoint.client.Do(req)
+	challengeManager, foundV2, err := PingV2Registry(endpoint, authTransport)
 	if err != nil {
 	if err != nil {
-		return err
+		if !foundV2 {
+			err = fallbackError{err: err}
+		}
+		return "", err
 	}
 	}
-	defer resp.Body.Close()
 
 
-	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("basic auth attempt to %s realm %q failed with status: %d %s", registryEndpoint, params["realm"], resp.StatusCode, http.StatusText(resp.StatusCode))
+	creds := loginCredentialStore{
+		authConfig: authConfig,
 	}
 	}
 
 
-	return nil
-}
+	tokenHandler := auth.NewTokenHandler(authTransport, creds, "")
+	basicHandler := auth.NewBasicHandler(creds)
+	modifiers = append(modifiers, auth.NewAuthorizer(challengeManager, tokenHandler, basicHandler))
+	tr := transport.NewTransport(authTransport, modifiers...)
 
 
-func tryV2TokenAuthLogin(authConfig *types.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
-	token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint)
-	if err != nil {
-		return err
+	loginClient := &http.Client{
+		Transport: tr,
+		Timeout:   15 * time.Second,
 	}
 	}
 
 
-	req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
+	endpointStr := strings.TrimRight(endpoint.URL.String(), "/") + "/v2/"
+	req, err := http.NewRequest("GET", endpointStr, nil)
 	if err != nil {
 	if err != nil {
-		return err
+		if !foundV2 {
+			err = fallbackError{err: err}
+		}
+		return "", err
 	}
 	}
 
 
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
-
-	resp, err := registryEndpoint.client.Do(req)
+	resp, err := loginClient.Do(req)
 	if err != nil {
 	if err != nil {
-		return err
+		if !foundV2 {
+			err = fallbackError{err: err}
+		}
+		return "", err
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
 
 
 	if resp.StatusCode != http.StatusOK {
 	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("token auth attempt to %s realm %q failed with status: %d %s", registryEndpoint, params["realm"], resp.StatusCode, http.StatusText(resp.StatusCode))
+		// TODO(dmcgowan): Attempt to further interpret result, status code and error code string
+		err := fmt.Errorf("login attempt to %s failed with status: %d %s", endpointStr, resp.StatusCode, http.StatusText(resp.StatusCode))
+		if !foundV2 {
+			err = fallbackError{err: err}
+		}
+		return "", err
 	}
 	}
 
 
-	return nil
+	return "Login Succeeded", nil
+
 }
 }
 
 
 // ResolveAuthConfig matches an auth configuration to a server address or a URL
 // ResolveAuthConfig matches an auth configuration to a server address or a URL
@@ -193,3 +181,63 @@ func ResolveAuthConfig(authConfigs map[string]types.AuthConfig, index *registryt
 	// When all else fails, return an empty auth config
 	// When all else fails, return an empty auth config
 	return types.AuthConfig{}
 	return types.AuthConfig{}
 }
 }
+
+// PingResponseError is used when the response from a ping
+// was received but invalid.
+type PingResponseError struct {
+	Err error
+}
+
+func (err PingResponseError) Error() string {
+	return err.Error()
+}
+
+// PingV2Registry attempts to ping a v2 registry and on success return a
+// challenge manager for the supported authentication types and
+// whether v2 was confirmed by the response. If a response is received but
+// cannot be interpreted a PingResponseError will be returned.
+func PingV2Registry(endpoint APIEndpoint, transport http.RoundTripper) (auth.ChallengeManager, bool, error) {
+	var (
+		foundV2   = false
+		v2Version = auth.APIVersion{
+			Type:    "registry",
+			Version: "2.0",
+		}
+	)
+
+	pingClient := &http.Client{
+		Transport: transport,
+		Timeout:   15 * time.Second,
+	}
+	endpointStr := strings.TrimRight(endpoint.URL.String(), "/") + "/v2/"
+	req, err := http.NewRequest("GET", endpointStr, nil)
+	if err != nil {
+		return nil, false, err
+	}
+	resp, err := pingClient.Do(req)
+	if err != nil {
+		return nil, false, err
+	}
+	defer resp.Body.Close()
+
+	versions := auth.APIVersions(resp, DefaultRegistryVersionHeader)
+	for _, pingVersion := range versions {
+		if pingVersion == v2Version {
+			// The version header indicates we're definitely
+			// talking to a v2 registry. So don't allow future
+			// fallbacks to the v1 protocol.
+
+			foundV2 = true
+			break
+		}
+	}
+
+	challengeManager := auth.NewSimpleChallengeManager()
+	if err := challengeManager.AddResponse(resp); err != nil {
+		return nil, foundV2, PingResponseError{
+			Err: err,
+		}
+	}
+
+	return challengeManager, foundV2, nil
+}

+ 0 - 150
registry/authchallenge.go

@@ -1,150 +0,0 @@
-package registry
-
-import (
-	"net/http"
-	"strings"
-)
-
-// Octet types from RFC 2616.
-type octetType byte
-
-// AuthorizationChallenge carries information
-// from a WWW-Authenticate response header.
-type AuthorizationChallenge struct {
-	Scheme     string
-	Parameters map[string]string
-}
-
-var octetTypes [256]octetType
-
-const (
-	isToken octetType = 1 << iota
-	isSpace
-)
-
-func init() {
-	// OCTET      = <any 8-bit sequence of data>
-	// CHAR       = <any US-ASCII character (octets 0 - 127)>
-	// CTL        = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
-	// CR         = <US-ASCII CR, carriage return (13)>
-	// LF         = <US-ASCII LF, linefeed (10)>
-	// SP         = <US-ASCII SP, space (32)>
-	// HT         = <US-ASCII HT, horizontal-tab (9)>
-	// <">        = <US-ASCII double-quote mark (34)>
-	// CRLF       = CR LF
-	// LWS        = [CRLF] 1*( SP | HT )
-	// TEXT       = <any OCTET except CTLs, but including LWS>
-	// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
-	//              | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
-	// token      = 1*<any CHAR except CTLs or separators>
-	// qdtext     = <any TEXT except <">>
-
-	for c := 0; c < 256; c++ {
-		var t octetType
-		isCtl := c <= 31 || c == 127
-		isChar := 0 <= c && c <= 127
-		isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
-		if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
-			t |= isSpace
-		}
-		if isChar && !isCtl && !isSeparator {
-			t |= isToken
-		}
-		octetTypes[c] = t
-	}
-}
-
-func parseAuthHeader(header http.Header) []*AuthorizationChallenge {
-	var challenges []*AuthorizationChallenge
-	for _, h := range header[http.CanonicalHeaderKey("WWW-Authenticate")] {
-		v, p := parseValueAndParams(h)
-		if v != "" {
-			challenges = append(challenges, &AuthorizationChallenge{Scheme: v, Parameters: p})
-		}
-	}
-	return challenges
-}
-
-func parseValueAndParams(header string) (value string, params map[string]string) {
-	params = make(map[string]string)
-	value, s := expectToken(header)
-	if value == "" {
-		return
-	}
-	value = strings.ToLower(value)
-	s = "," + skipSpace(s)
-	for strings.HasPrefix(s, ",") {
-		var pkey string
-		pkey, s = expectToken(skipSpace(s[1:]))
-		if pkey == "" {
-			return
-		}
-		if !strings.HasPrefix(s, "=") {
-			return
-		}
-		var pvalue string
-		pvalue, s = expectTokenOrQuoted(s[1:])
-		if pvalue == "" {
-			return
-		}
-		pkey = strings.ToLower(pkey)
-		params[pkey] = pvalue
-		s = skipSpace(s)
-	}
-	return
-}
-
-func skipSpace(s string) (rest string) {
-	i := 0
-	for ; i < len(s); i++ {
-		if octetTypes[s[i]]&isSpace == 0 {
-			break
-		}
-	}
-	return s[i:]
-}
-
-func expectToken(s string) (token, rest string) {
-	i := 0
-	for ; i < len(s); i++ {
-		if octetTypes[s[i]]&isToken == 0 {
-			break
-		}
-	}
-	return s[:i], s[i:]
-}
-
-func expectTokenOrQuoted(s string) (value string, rest string) {
-	if !strings.HasPrefix(s, "\"") {
-		return expectToken(s)
-	}
-	s = s[1:]
-	for i := 0; i < len(s); i++ {
-		switch s[i] {
-		case '"':
-			return s[:i], s[i+1:]
-		case '\\':
-			p := make([]byte, len(s)-1)
-			j := copy(p, s[:i])
-			escape := true
-			for i = i + i; i < len(s); i++ {
-				b := s[i]
-				switch {
-				case escape:
-					escape = false
-					p[j] = b
-					j++
-				case b == '\\':
-					escape = true
-				case b == '"':
-					return string(p[:j]), s[i+1:]
-				default:
-					p[j] = b
-					j++
-				}
-			}
-			return "", ""
-		}
-	}
-	return "", ""
-}

+ 3 - 0
registry/config.go

@@ -49,6 +49,9 @@ var (
 	V2Only = false
 	V2Only = false
 )
 )
 
 
+// for mocking in unit tests
+var lookupIP = net.LookupIP
+
 // InstallFlags adds command-line options to the top-level flag parser for
 // InstallFlags adds command-line options to the top-level flag parser for
 // the current process.
 // the current process.
 func (options *Options) InstallFlags(cmd *flag.FlagSet, usageFn func(string) string) {
 func (options *Options) InstallFlags(cmd *flag.FlagSet, usageFn func(string) string) {

+ 24 - 39
registry/endpoint_test.go

@@ -14,12 +14,13 @@ func TestEndpointParse(t *testing.T) {
 	}{
 	}{
 		{IndexServer, IndexServer},
 		{IndexServer, IndexServer},
 		{"http://0.0.0.0:5000/v1/", "http://0.0.0.0:5000/v1/"},
 		{"http://0.0.0.0:5000/v1/", "http://0.0.0.0:5000/v1/"},
-		{"http://0.0.0.0:5000/v2/", "http://0.0.0.0:5000/v2/"},
-		{"http://0.0.0.0:5000", "http://0.0.0.0:5000/v0/"},
-		{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
+		{"http://0.0.0.0:5000", "http://0.0.0.0:5000/v1/"},
+		{"0.0.0.0:5000", "https://0.0.0.0:5000/v1/"},
+		{"http://0.0.0.0:5000/nonversion/", "http://0.0.0.0:5000/nonversion/v1/"},
+		{"http://0.0.0.0:5000/v0/", "http://0.0.0.0:5000/v0/v1/"},
 	}
 	}
 	for _, td := range testData {
 	for _, td := range testData {
-		e, err := newEndpointFromStr(td.str, nil, "", nil)
+		e, err := newV1EndpointFromStr(td.str, nil, "", nil)
 		if err != nil {
 		if err != nil {
 			t.Errorf("%q: %s", td.str, err)
 			t.Errorf("%q: %s", td.str, err)
 		}
 		}
@@ -33,21 +34,26 @@ func TestEndpointParse(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestEndpointParseInvalid(t *testing.T) {
+	testData := []string{
+		"http://0.0.0.0:5000/v2/",
+	}
+	for _, td := range testData {
+		e, err := newV1EndpointFromStr(td, nil, "", nil)
+		if err == nil {
+			t.Errorf("expected error parsing %q: parsed as %q", td, e)
+		}
+	}
+}
+
 // Ensure that a registry endpoint that responds with a 401 only is determined
 // Ensure that a registry endpoint that responds with a 401 only is determined
-// to be a v1 registry unless it includes a valid v2 API header.
-func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
+// to be a valid v1 registry endpoint
+func TestValidateEndpoint(t *testing.T) {
 	requireBasicAuthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	requireBasicAuthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Add("WWW-Authenticate", `Basic realm="localhost"`)
 		w.Header().Add("WWW-Authenticate", `Basic realm="localhost"`)
 		w.WriteHeader(http.StatusUnauthorized)
 		w.WriteHeader(http.StatusUnauthorized)
 	})
 	})
 
 
-	requireBasicAuthHandlerV2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// This mock server supports v2.0, v2.1, v42.0, and v100.0
-		w.Header().Add("Docker-Distribution-API-Version", "registry/100.0 registry/42.0")
-		w.Header().Add("Docker-Distribution-API-Version", "registry/2.0 registry/2.1")
-		requireBasicAuthHandler.ServeHTTP(w, r)
-	})
-
 	// Make a test server which should validate as a v1 server.
 	// Make a test server which should validate as a v1 server.
 	testServer := httptest.NewServer(requireBasicAuthHandler)
 	testServer := httptest.NewServer(requireBasicAuthHandler)
 	defer testServer.Close()
 	defer testServer.Close()
@@ -57,37 +63,16 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	testEndpoint := Endpoint{
-		URL:     testServerURL,
-		Version: APIVersionUnknown,
-		client:  HTTPClient(NewTransport(nil)),
+	testEndpoint := V1Endpoint{
+		URL:    testServerURL,
+		client: HTTPClient(NewTransport(nil)),
 	}
 	}
 
 
 	if err = validateEndpoint(&testEndpoint); err != nil {
 	if err = validateEndpoint(&testEndpoint); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if testEndpoint.Version != APIVersion1 {
-		t.Fatalf("expected endpoint to validate to %d, got %d", APIVersion1, testEndpoint.Version)
-	}
-
-	// Make a test server which should validate as a v2 server.
-	testServer = httptest.NewServer(requireBasicAuthHandlerV2)
-	defer testServer.Close()
-
-	testServerURL, err = url.Parse(testServer.URL)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	testEndpoint.URL = testServerURL
-	testEndpoint.Version = APIVersionUnknown
-
-	if err = validateEndpoint(&testEndpoint); err != nil {
-		t.Fatal(err)
-	}
-
-	if testEndpoint.Version != APIVersion2 {
-		t.Fatalf("expected endpoint to validate to %d, got %d", APIVersion2, testEndpoint.Version)
+	if testEndpoint.URL.Scheme != "http" {
+		t.Fatalf("expecting to validate endpoint as http, got url %s", testEndpoint.String())
 	}
 	}
 }
 }

+ 51 - 142
registry/endpoint.go → registry/endpoint_v1.go

@@ -5,60 +5,35 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
-	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
-	"github.com/docker/distribution/registry/api/v2"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/distribution/registry/client/transport"
 	registrytypes "github.com/docker/engine-api/types/registry"
 	registrytypes "github.com/docker/engine-api/types/registry"
 )
 )
 
 
-// for mocking in unit tests
-var lookupIP = net.LookupIP
-
-// scans string for api version in the URL path. returns the trimmed address, if version found, string and API version.
-func scanForAPIVersion(address string) (string, APIVersion) {
-	var (
-		chunks        []string
-		apiVersionStr string
-	)
-
-	if strings.HasSuffix(address, "/") {
-		address = address[:len(address)-1]
-	}
-
-	chunks = strings.Split(address, "/")
-	apiVersionStr = chunks[len(chunks)-1]
-
-	for k, v := range apiVersions {
-		if apiVersionStr == v {
-			address = strings.Join(chunks[:len(chunks)-1], "/")
-			return address, k
-		}
-	}
-
-	return address, APIVersionUnknown
+// V1Endpoint stores basic information about a V1 registry endpoint.
+type V1Endpoint struct {
+	client   *http.Client
+	URL      *url.URL
+	IsSecure bool
 }
 }
 
 
-// NewEndpoint parses the given address to return a registry endpoint.  v can be used to
+// NewV1Endpoint parses the given address to return a registry endpoint.  v can be used to
 // specify a specific endpoint version
 // specify a specific endpoint version
-func NewEndpoint(index *registrytypes.IndexInfo, userAgent string, metaHeaders http.Header, v APIVersion) (*Endpoint, error) {
+func NewV1Endpoint(index *registrytypes.IndexInfo, userAgent string, metaHeaders http.Header) (*V1Endpoint, error) {
 	tlsConfig, err := newTLSConfig(index.Name, index.Secure)
 	tlsConfig, err := newTLSConfig(index.Name, index.Secure)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	endpoint, err := newEndpointFromStr(GetAuthConfigKey(index), tlsConfig, userAgent, metaHeaders)
+	endpoint, err := newV1EndpointFromStr(GetAuthConfigKey(index), tlsConfig, userAgent, metaHeaders)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if v != APIVersionUnknown {
-		endpoint.Version = v
-	}
 	if err := validateEndpoint(endpoint); err != nil {
 	if err := validateEndpoint(endpoint); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -66,7 +41,7 @@ func NewEndpoint(index *registrytypes.IndexInfo, userAgent string, metaHeaders h
 	return endpoint, nil
 	return endpoint, nil
 }
 }
 
 
-func validateEndpoint(endpoint *Endpoint) error {
+func validateEndpoint(endpoint *V1Endpoint) error {
 	logrus.Debugf("pinging registry endpoint %s", endpoint)
 	logrus.Debugf("pinging registry endpoint %s", endpoint)
 
 
 	// Try HTTPS ping to registry
 	// Try HTTPS ping to registry
@@ -93,11 +68,10 @@ func validateEndpoint(endpoint *Endpoint) error {
 	return nil
 	return nil
 }
 }
 
 
-func newEndpoint(address url.URL, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*Endpoint, error) {
-	endpoint := &Endpoint{
+func newV1Endpoint(address url.URL, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*V1Endpoint, error) {
+	endpoint := &V1Endpoint{
 		IsSecure: (tlsConfig == nil || !tlsConfig.InsecureSkipVerify),
 		IsSecure: (tlsConfig == nil || !tlsConfig.InsecureSkipVerify),
 		URL:      new(url.URL),
 		URL:      new(url.URL),
-		Version:  APIVersionUnknown,
 	}
 	}
 
 
 	*endpoint.URL = address
 	*endpoint.URL = address
@@ -108,86 +82,69 @@ func newEndpoint(address url.URL, tlsConfig *tls.Config, userAgent string, metaH
 	return endpoint, nil
 	return endpoint, nil
 }
 }
 
 
-func newEndpointFromStr(address string, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*Endpoint, error) {
+// trimV1Address trims the version off the address and returns the
+// trimmed address or an error if there is a non-V1 version.
+func trimV1Address(address string) (string, error) {
+	var (
+		chunks        []string
+		apiVersionStr string
+	)
+
+	if strings.HasSuffix(address, "/") {
+		address = address[:len(address)-1]
+	}
+
+	chunks = strings.Split(address, "/")
+	apiVersionStr = chunks[len(chunks)-1]
+	if apiVersionStr == "v1" {
+		return strings.Join(chunks[:len(chunks)-1], "/"), nil
+	}
+
+	for k, v := range apiVersions {
+		if k != APIVersion1 && apiVersionStr == v {
+			return "", fmt.Errorf("unsupported V1 version path %s", apiVersionStr)
+		}
+	}
+
+	return address, nil
+}
+
+func newV1EndpointFromStr(address string, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*V1Endpoint, error) {
 	if !strings.HasPrefix(address, "http://") && !strings.HasPrefix(address, "https://") {
 	if !strings.HasPrefix(address, "http://") && !strings.HasPrefix(address, "https://") {
 		address = "https://" + address
 		address = "https://" + address
 	}
 	}
 
 
-	trimmedAddress, detectedVersion := scanForAPIVersion(address)
+	address, err := trimV1Address(address)
+	if err != nil {
+		return nil, err
+	}
 
 
-	uri, err := url.Parse(trimmedAddress)
+	uri, err := url.Parse(address)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	endpoint, err := newEndpoint(*uri, tlsConfig, userAgent, metaHeaders)
+	endpoint, err := newV1Endpoint(*uri, tlsConfig, userAgent, metaHeaders)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	endpoint.Version = detectedVersion
 	return endpoint, nil
 	return endpoint, nil
 }
 }
 
 
-// Endpoint stores basic information about a registry endpoint.
-type Endpoint struct {
-	client         *http.Client
-	URL            *url.URL
-	Version        APIVersion
-	IsSecure       bool
-	AuthChallenges []*AuthorizationChallenge
-	URLBuilder     *v2.URLBuilder
-}
-
 // Get the formatted URL for the root of this registry Endpoint
 // Get the formatted URL for the root of this registry Endpoint
-func (e *Endpoint) String() string {
-	return fmt.Sprintf("%s/v%d/", e.URL, e.Version)
-}
-
-// VersionString returns a formatted string of this
-// endpoint address using the given API Version.
-func (e *Endpoint) VersionString(version APIVersion) string {
-	return fmt.Sprintf("%s/v%d/", e.URL, version)
+func (e *V1Endpoint) String() string {
+	return e.URL.String() + "/v1/"
 }
 }
 
 
 // Path returns a formatted string for the URL
 // Path returns a formatted string for the URL
 // of this endpoint with the given path appended.
 // of this endpoint with the given path appended.
-func (e *Endpoint) Path(path string) string {
-	return fmt.Sprintf("%s/v%d/%s", e.URL, e.Version, path)
-}
-
-// Ping pings the remote endpoint with v2 and v1 pings to determine the API
-// version. It returns a PingResult containing the discovered version. The
-// PingResult also indicates whether the registry is standalone or not.
-func (e *Endpoint) Ping() (PingResult, error) {
-	// The ping logic to use is determined by the registry endpoint version.
-	switch e.Version {
-	case APIVersion1:
-		return e.pingV1()
-	case APIVersion2:
-		return e.pingV2()
-	}
-
-	// APIVersionUnknown
-	// We should try v2 first...
-	e.Version = APIVersion2
-	regInfo, errV2 := e.pingV2()
-	if errV2 == nil {
-		return regInfo, nil
-	}
-
-	// ... then fallback to v1.
-	e.Version = APIVersion1
-	regInfo, errV1 := e.pingV1()
-	if errV1 == nil {
-		return regInfo, nil
-	}
-
-	e.Version = APIVersionUnknown
-	return PingResult{}, fmt.Errorf("unable to ping registry endpoint %s\nv2 ping attempt failed with error: %s\n v1 ping attempt failed with error: %s", e, errV2, errV1)
+func (e *V1Endpoint) Path(path string) string {
+	return e.URL.String() + "/v1/" + path
 }
 }
 
 
-func (e *Endpoint) pingV1() (PingResult, error) {
+// Ping returns a PingResult which indicates whether the registry is standalone or not.
+func (e *V1Endpoint) Ping() (PingResult, error) {
 	logrus.Debugf("attempting v1 ping for registry endpoint %s", e)
 	logrus.Debugf("attempting v1 ping for registry endpoint %s", e)
 
 
 	if e.String() == IndexServer {
 	if e.String() == IndexServer {
@@ -240,51 +197,3 @@ func (e *Endpoint) pingV1() (PingResult, error) {
 	logrus.Debugf("PingResult.Standalone: %t", info.Standalone)
 	logrus.Debugf("PingResult.Standalone: %t", info.Standalone)
 	return info, nil
 	return info, nil
 }
 }
-
-func (e *Endpoint) pingV2() (PingResult, error) {
-	logrus.Debugf("attempting v2 ping for registry endpoint %s", e)
-
-	req, err := http.NewRequest("GET", e.Path(""), nil)
-	if err != nil {
-		return PingResult{}, err
-	}
-
-	resp, err := e.client.Do(req)
-	if err != nil {
-		return PingResult{}, err
-	}
-	defer resp.Body.Close()
-
-	// The endpoint may have multiple supported versions.
-	// Ensure it supports the v2 Registry API.
-	var supportsV2 bool
-
-HeaderLoop:
-	for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey("Docker-Distribution-API-Version")] {
-		for _, versionName := range strings.Fields(supportedVersions) {
-			if versionName == "registry/2.0" {
-				supportsV2 = true
-				break HeaderLoop
-			}
-		}
-	}
-
-	if !supportsV2 {
-		return PingResult{}, fmt.Errorf("%s does not appear to be a v2 registry endpoint", e)
-	}
-
-	if resp.StatusCode == http.StatusOK {
-		// It would seem that no authentication/authorization is required.
-		// So we don't need to parse/add any authorization schemes.
-		return PingResult{Standalone: true}, nil
-	}
-
-	if resp.StatusCode == http.StatusUnauthorized {
-		// Parse the WWW-Authenticate Header and store the challenges
-		// on this endpoint object.
-		e.AuthChallenges = parseAuthHeader(resp.Header)
-		return PingResult{}, nil
-	}
-
-	return PingResult{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode))
-}

+ 9 - 27
registry/registry_test.go

@@ -25,7 +25,7 @@ const (
 
 
 func spawnTestRegistrySession(t *testing.T) *Session {
 func spawnTestRegistrySession(t *testing.T) *Session {
 	authConfig := &types.AuthConfig{}
 	authConfig := &types.AuthConfig{}
-	endpoint, err := NewEndpoint(makeIndex("/v1/"), "", nil, APIVersionUnknown)
+	endpoint, err := NewV1Endpoint(makeIndex("/v1/"), "", nil)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -53,7 +53,7 @@ func spawnTestRegistrySession(t *testing.T) *Session {
 
 
 func TestPingRegistryEndpoint(t *testing.T) {
 func TestPingRegistryEndpoint(t *testing.T) {
 	testPing := func(index *registrytypes.IndexInfo, expectedStandalone bool, assertMessage string) {
 	testPing := func(index *registrytypes.IndexInfo, expectedStandalone bool, assertMessage string) {
-		ep, err := NewEndpoint(index, "", nil, APIVersionUnknown)
+		ep, err := NewV1Endpoint(index, "", nil)
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
@@ -72,8 +72,8 @@ func TestPingRegistryEndpoint(t *testing.T) {
 
 
 func TestEndpoint(t *testing.T) {
 func TestEndpoint(t *testing.T) {
 	// Simple wrapper to fail test if err != nil
 	// Simple wrapper to fail test if err != nil
-	expandEndpoint := func(index *registrytypes.IndexInfo) *Endpoint {
-		endpoint, err := NewEndpoint(index, "", nil, APIVersionUnknown)
+	expandEndpoint := func(index *registrytypes.IndexInfo) *V1Endpoint {
+		endpoint, err := NewV1Endpoint(index, "", nil)
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
@@ -82,7 +82,7 @@ func TestEndpoint(t *testing.T) {
 
 
 	assertInsecureIndex := func(index *registrytypes.IndexInfo) {
 	assertInsecureIndex := func(index *registrytypes.IndexInfo) {
 		index.Secure = true
 		index.Secure = true
-		_, err := NewEndpoint(index, "", nil, APIVersionUnknown)
+		_, err := NewV1Endpoint(index, "", nil)
 		assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index")
 		assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index")
 		assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry  error for insecure index")
 		assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry  error for insecure index")
 		index.Secure = false
 		index.Secure = false
@@ -90,7 +90,7 @@ func TestEndpoint(t *testing.T) {
 
 
 	assertSecureIndex := func(index *registrytypes.IndexInfo) {
 	assertSecureIndex := func(index *registrytypes.IndexInfo) {
 		index.Secure = true
 		index.Secure = true
-		_, err := NewEndpoint(index, "", nil, APIVersionUnknown)
+		_, err := NewV1Endpoint(index, "", nil)
 		assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index")
 		assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index")
 		assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index")
 		assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index")
 		index.Secure = false
 		index.Secure = false
@@ -100,51 +100,33 @@ func TestEndpoint(t *testing.T) {
 	index.Name = makeURL("/v1/")
 	index.Name = makeURL("/v1/")
 	endpoint := expandEndpoint(index)
 	endpoint := expandEndpoint(index)
 	assertEqual(t, endpoint.String(), index.Name, "Expected endpoint to be "+index.Name)
 	assertEqual(t, endpoint.String(), index.Name, "Expected endpoint to be "+index.Name)
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertInsecureIndex(index)
 	assertInsecureIndex(index)
 
 
 	index.Name = makeURL("")
 	index.Name = makeURL("")
 	endpoint = expandEndpoint(index)
 	endpoint = expandEndpoint(index)
 	assertEqual(t, endpoint.String(), index.Name+"/v1/", index.Name+": Expected endpoint to be "+index.Name+"/v1/")
 	assertEqual(t, endpoint.String(), index.Name+"/v1/", index.Name+": Expected endpoint to be "+index.Name+"/v1/")
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertInsecureIndex(index)
 	assertInsecureIndex(index)
 
 
 	httpURL := makeURL("")
 	httpURL := makeURL("")
 	index.Name = strings.SplitN(httpURL, "://", 2)[1]
 	index.Name = strings.SplitN(httpURL, "://", 2)[1]
 	endpoint = expandEndpoint(index)
 	endpoint = expandEndpoint(index)
 	assertEqual(t, endpoint.String(), httpURL+"/v1/", index.Name+": Expected endpoint to be "+httpURL+"/v1/")
 	assertEqual(t, endpoint.String(), httpURL+"/v1/", index.Name+": Expected endpoint to be "+httpURL+"/v1/")
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertInsecureIndex(index)
 	assertInsecureIndex(index)
 
 
 	index.Name = makeHTTPSURL("/v1/")
 	index.Name = makeHTTPSURL("/v1/")
 	endpoint = expandEndpoint(index)
 	endpoint = expandEndpoint(index)
 	assertEqual(t, endpoint.String(), index.Name, "Expected endpoint to be "+index.Name)
 	assertEqual(t, endpoint.String(), index.Name, "Expected endpoint to be "+index.Name)
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertSecureIndex(index)
 	assertSecureIndex(index)
 
 
 	index.Name = makeHTTPSURL("")
 	index.Name = makeHTTPSURL("")
 	endpoint = expandEndpoint(index)
 	endpoint = expandEndpoint(index)
 	assertEqual(t, endpoint.String(), index.Name+"/v1/", index.Name+": Expected endpoint to be "+index.Name+"/v1/")
 	assertEqual(t, endpoint.String(), index.Name+"/v1/", index.Name+": Expected endpoint to be "+index.Name+"/v1/")
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertSecureIndex(index)
 	assertSecureIndex(index)
 
 
 	httpsURL := makeHTTPSURL("")
 	httpsURL := makeHTTPSURL("")
 	index.Name = strings.SplitN(httpsURL, "://", 2)[1]
 	index.Name = strings.SplitN(httpsURL, "://", 2)[1]
 	endpoint = expandEndpoint(index)
 	endpoint = expandEndpoint(index)
 	assertEqual(t, endpoint.String(), httpsURL+"/v1/", index.Name+": Expected endpoint to be "+httpsURL+"/v1/")
 	assertEqual(t, endpoint.String(), httpsURL+"/v1/", index.Name+": Expected endpoint to be "+httpsURL+"/v1/")
-	if endpoint.Version != APIVersion1 {
-		t.Fatal("Expected endpoint to be v1")
-	}
 	assertSecureIndex(index)
 	assertSecureIndex(index)
 
 
 	badEndpoints := []string{
 	badEndpoints := []string{
@@ -156,7 +138,7 @@ func TestEndpoint(t *testing.T) {
 	}
 	}
 	for _, address := range badEndpoints {
 	for _, address := range badEndpoints {
 		index.Name = address
 		index.Name = address
-		_, err := NewEndpoint(index, "", nil, APIVersionUnknown)
+		_, err := NewV1Endpoint(index, "", nil)
 		checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint")
 		checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint")
 	}
 	}
 }
 }
@@ -685,7 +667,7 @@ func TestMirrorEndpointLookup(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Error(err)
 		t.Error(err)
 	}
 	}
-	pushAPIEndpoints, err := s.LookupPushEndpoints(imageName)
+	pushAPIEndpoints, err := s.LookupPushEndpoints(imageName.Hostname())
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -693,7 +675,7 @@ func TestMirrorEndpointLookup(t *testing.T) {
 		t.Fatal("Push endpoint should not contain mirror")
 		t.Fatal("Push endpoint should not contain mirror")
 	}
 	}
 
 
-	pullAPIEndpoints, err := s.LookupPullEndpoints(imageName)
+	pullAPIEndpoints, err := s.LookupPullEndpoints(imageName.Hostname())
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}

+ 29 - 26
registry/service.go

@@ -6,6 +6,7 @@ import (
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 
 
+	"github.com/Sirupsen/logrus"
 	"github.com/docker/docker/reference"
 	"github.com/docker/docker/reference"
 	"github.com/docker/engine-api/types"
 	"github.com/docker/engine-api/types"
 	registrytypes "github.com/docker/engine-api/types/registry"
 	registrytypes "github.com/docker/engine-api/types/registry"
@@ -28,29 +29,31 @@ func NewService(options *Options) *Service {
 // Auth contacts the public registry with the provided credentials,
 // Auth contacts the public registry with the provided credentials,
 // and returns OK if authentication was successful.
 // and returns OK if authentication was successful.
 // It can be used to verify the validity of a client's credentials.
 // It can be used to verify the validity of a client's credentials.
-func (s *Service) Auth(authConfig *types.AuthConfig, userAgent string) (string, error) {
-	addr := authConfig.ServerAddress
-	if addr == "" {
-		// Use the official registry address if not specified.
-		addr = IndexServer
-	}
-	index, err := s.ResolveIndex(addr)
+func (s *Service) Auth(authConfig *types.AuthConfig, userAgent string) (status string, err error) {
+	endpoints, err := s.LookupPushEndpoints(authConfig.ServerAddress)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
 
 
-	endpointVersion := APIVersion(APIVersionUnknown)
-	if V2Only {
-		// Override the endpoint to only attempt a v2 ping
-		endpointVersion = APIVersion2
-	}
+	for _, endpoint := range endpoints {
+		login := loginV2
+		if endpoint.Version == APIVersion1 {
+			login = loginV1
+		}
 
 
-	endpoint, err := NewEndpoint(index, userAgent, nil, endpointVersion)
-	if err != nil {
+		status, err = login(authConfig, endpoint, userAgent)
+		if err == nil {
+			return
+		}
+		if fErr, ok := err.(fallbackError); ok {
+			err = fErr.err
+			logrus.Infof("Error logging in to %s endpoint, trying next endpoint: %v", endpoint.Version, err)
+			continue
+		}
 		return "", err
 		return "", err
 	}
 	}
-	authConfig.ServerAddress = endpoint.String()
-	return Login(authConfig, endpoint)
+
+	return "", err
 }
 }
 
 
 // splitReposSearchTerm breaks a search term into an index name and remote name
 // splitReposSearchTerm breaks a search term into an index name and remote name
@@ -85,7 +88,7 @@ func (s *Service) Search(term string, authConfig *types.AuthConfig, userAgent st
 	}
 	}
 
 
 	// *TODO: Search multiple indexes.
 	// *TODO: Search multiple indexes.
-	endpoint, err := NewEndpoint(index, userAgent, http.Header(headers), APIVersionUnknown)
+	endpoint, err := NewV1Endpoint(index, userAgent, http.Header(headers))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -129,8 +132,8 @@ type APIEndpoint struct {
 }
 }
 
 
 // ToV1Endpoint returns a V1 API endpoint based on the APIEndpoint
 // ToV1Endpoint returns a V1 API endpoint based on the APIEndpoint
-func (e APIEndpoint) ToV1Endpoint(userAgent string, metaHeaders http.Header) (*Endpoint, error) {
-	return newEndpoint(*e.URL, e.TLSConfig, userAgent, metaHeaders)
+func (e APIEndpoint) ToV1Endpoint(userAgent string, metaHeaders http.Header) (*V1Endpoint, error) {
+	return newV1Endpoint(*e.URL, e.TLSConfig, userAgent, metaHeaders)
 }
 }
 
 
 // TLSConfig constructs a client TLS configuration based on server defaults
 // TLSConfig constructs a client TLS configuration based on server defaults
@@ -145,15 +148,15 @@ func (s *Service) tlsConfigForMirror(mirrorURL *url.URL) (*tls.Config, error) {
 // LookupPullEndpoints creates an list of endpoints to try to pull from, in order of preference.
 // LookupPullEndpoints creates an list of endpoints to try to pull from, in order of preference.
 // It gives preference to v2 endpoints over v1, mirrors over the actual
 // It gives preference to v2 endpoints over v1, mirrors over the actual
 // registry, and HTTPS over plain HTTP.
 // registry, and HTTPS over plain HTTP.
-func (s *Service) LookupPullEndpoints(repoName reference.Named) (endpoints []APIEndpoint, err error) {
-	return s.lookupEndpoints(repoName)
+func (s *Service) LookupPullEndpoints(hostname string) (endpoints []APIEndpoint, err error) {
+	return s.lookupEndpoints(hostname)
 }
 }
 
 
 // LookupPushEndpoints creates an list of endpoints to try to push to, in order of preference.
 // LookupPushEndpoints creates an list of endpoints to try to push to, in order of preference.
 // It gives preference to v2 endpoints over v1, and HTTPS over plain HTTP.
 // It gives preference to v2 endpoints over v1, and HTTPS over plain HTTP.
 // Mirrors are not included.
 // Mirrors are not included.
-func (s *Service) LookupPushEndpoints(repoName reference.Named) (endpoints []APIEndpoint, err error) {
-	allEndpoints, err := s.lookupEndpoints(repoName)
+func (s *Service) LookupPushEndpoints(hostname string) (endpoints []APIEndpoint, err error) {
+	allEndpoints, err := s.lookupEndpoints(hostname)
 	if err == nil {
 	if err == nil {
 		for _, endpoint := range allEndpoints {
 		for _, endpoint := range allEndpoints {
 			if !endpoint.Mirror {
 			if !endpoint.Mirror {
@@ -164,8 +167,8 @@ func (s *Service) LookupPushEndpoints(repoName reference.Named) (endpoints []API
 	return endpoints, err
 	return endpoints, err
 }
 }
 
 
-func (s *Service) lookupEndpoints(repoName reference.Named) (endpoints []APIEndpoint, err error) {
-	endpoints, err = s.lookupV2Endpoints(repoName)
+func (s *Service) lookupEndpoints(hostname string) (endpoints []APIEndpoint, err error) {
+	endpoints, err = s.lookupV2Endpoints(hostname)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -174,7 +177,7 @@ func (s *Service) lookupEndpoints(repoName reference.Named) (endpoints []APIEndp
 		return endpoints, nil
 		return endpoints, nil
 	}
 	}
 
 
-	legacyEndpoints, err := s.lookupV1Endpoints(repoName)
+	legacyEndpoints, err := s.lookupV1Endpoints(hostname)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 2 - 12
registry/service_v1.go

@@ -1,19 +1,15 @@
 package registry
 package registry
 
 
 import (
 import (
-	"fmt"
 	"net/url"
 	"net/url"
-	"strings"
 
 
-	"github.com/docker/docker/reference"
 	"github.com/docker/go-connections/tlsconfig"
 	"github.com/docker/go-connections/tlsconfig"
 )
 )
 
 
-func (s *Service) lookupV1Endpoints(repoName reference.Named) (endpoints []APIEndpoint, err error) {
+func (s *Service) lookupV1Endpoints(hostname string) (endpoints []APIEndpoint, err error) {
 	var cfg = tlsconfig.ServerDefault
 	var cfg = tlsconfig.ServerDefault
 	tlsConfig := &cfg
 	tlsConfig := &cfg
-	nameString := repoName.FullName()
-	if strings.HasPrefix(nameString, DefaultNamespace+"/") {
+	if hostname == DefaultNamespace {
 		endpoints = append(endpoints, APIEndpoint{
 		endpoints = append(endpoints, APIEndpoint{
 			URL:          DefaultV1Registry,
 			URL:          DefaultV1Registry,
 			Version:      APIVersion1,
 			Version:      APIVersion1,
@@ -24,12 +20,6 @@ func (s *Service) lookupV1Endpoints(repoName reference.Named) (endpoints []APIEn
 		return endpoints, nil
 		return endpoints, nil
 	}
 	}
 
 
-	slashIndex := strings.IndexRune(nameString, '/')
-	if slashIndex <= 0 {
-		return nil, fmt.Errorf("invalid repo name: missing '/':  %s", nameString)
-	}
-	hostname := nameString[:slashIndex]
-
 	tlsConfig, err = s.TLSConfig(hostname)
 	tlsConfig, err = s.TLSConfig(hostname)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 2 - 11
registry/service_v2.go

@@ -1,19 +1,16 @@
 package registry
 package registry
 
 
 import (
 import (
-	"fmt"
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 
 
-	"github.com/docker/docker/reference"
 	"github.com/docker/go-connections/tlsconfig"
 	"github.com/docker/go-connections/tlsconfig"
 )
 )
 
 
-func (s *Service) lookupV2Endpoints(repoName reference.Named) (endpoints []APIEndpoint, err error) {
+func (s *Service) lookupV2Endpoints(hostname string) (endpoints []APIEndpoint, err error) {
 	var cfg = tlsconfig.ServerDefault
 	var cfg = tlsconfig.ServerDefault
 	tlsConfig := &cfg
 	tlsConfig := &cfg
-	nameString := repoName.FullName()
-	if strings.HasPrefix(nameString, DefaultNamespace+"/") {
+	if hostname == DefaultNamespace {
 		// v2 mirrors
 		// v2 mirrors
 		for _, mirror := range s.Config.Mirrors {
 		for _, mirror := range s.Config.Mirrors {
 			if !strings.HasPrefix(mirror, "http://") && !strings.HasPrefix(mirror, "https://") {
 			if !strings.HasPrefix(mirror, "http://") && !strings.HasPrefix(mirror, "https://") {
@@ -48,12 +45,6 @@ func (s *Service) lookupV2Endpoints(repoName reference.Named) (endpoints []APIEn
 		return endpoints, nil
 		return endpoints, nil
 	}
 	}
 
 
-	slashIndex := strings.IndexRune(nameString, '/')
-	if slashIndex <= 0 {
-		return nil, fmt.Errorf("invalid repo name: missing '/':  %s", nameString)
-	}
-	hostname := nameString[:slashIndex]
-
 	tlsConfig, err = s.TLSConfig(hostname)
 	tlsConfig, err = s.TLSConfig(hostname)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err

+ 8 - 8
registry/session.go

@@ -37,7 +37,7 @@ var (
 
 
 // A Session is used to communicate with a V1 registry
 // A Session is used to communicate with a V1 registry
 type Session struct {
 type Session struct {
-	indexEndpoint *Endpoint
+	indexEndpoint *V1Endpoint
 	client        *http.Client
 	client        *http.Client
 	// TODO(tiborvass): remove authConfig
 	// TODO(tiborvass): remove authConfig
 	authConfig *types.AuthConfig
 	authConfig *types.AuthConfig
@@ -163,7 +163,7 @@ func (tr *authTransport) CancelRequest(req *http.Request) {
 
 
 // NewSession creates a new session
 // NewSession creates a new session
 // TODO(tiborvass): remove authConfig param once registry client v2 is vendored
 // TODO(tiborvass): remove authConfig param once registry client v2 is vendored
-func NewSession(client *http.Client, authConfig *types.AuthConfig, endpoint *Endpoint) (r *Session, err error) {
+func NewSession(client *http.Client, authConfig *types.AuthConfig, endpoint *V1Endpoint) (r *Session, err error) {
 	r = &Session{
 	r = &Session{
 		authConfig:    authConfig,
 		authConfig:    authConfig,
 		client:        client,
 		client:        client,
@@ -175,7 +175,7 @@ func NewSession(client *http.Client, authConfig *types.AuthConfig, endpoint *End
 
 
 	// If we're working with a standalone private registry over HTTPS, send Basic Auth headers
 	// If we're working with a standalone private registry over HTTPS, send Basic Auth headers
 	// alongside all our requests.
 	// alongside all our requests.
-	if endpoint.VersionString(1) != IndexServer && endpoint.URL.Scheme == "https" {
+	if endpoint.String() != IndexServer && endpoint.URL.Scheme == "https" {
 		info, err := endpoint.Ping()
 		info, err := endpoint.Ping()
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
@@ -405,7 +405,7 @@ func buildEndpointsList(headers []string, indexEp string) ([]string, error) {
 
 
 // GetRepositoryData returns lists of images and endpoints for the repository
 // GetRepositoryData returns lists of images and endpoints for the repository
 func (r *Session) GetRepositoryData(name reference.Named) (*RepositoryData, error) {
 func (r *Session) GetRepositoryData(name reference.Named) (*RepositoryData, error) {
-	repositoryTarget := fmt.Sprintf("%srepositories/%s/images", r.indexEndpoint.VersionString(1), name.RemoteName())
+	repositoryTarget := fmt.Sprintf("%srepositories/%s/images", r.indexEndpoint.String(), name.RemoteName())
 
 
 	logrus.Debugf("[registry] Calling GET %s", repositoryTarget)
 	logrus.Debugf("[registry] Calling GET %s", repositoryTarget)
 
 
@@ -444,7 +444,7 @@ func (r *Session) GetRepositoryData(name reference.Named) (*RepositoryData, erro
 
 
 	var endpoints []string
 	var endpoints []string
 	if res.Header.Get("X-Docker-Endpoints") != "" {
 	if res.Header.Get("X-Docker-Endpoints") != "" {
-		endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1))
+		endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.String())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -634,7 +634,7 @@ func (r *Session) PushImageJSONIndex(remote reference.Named, imgList []*ImgData,
 	if validate {
 	if validate {
 		suffix = "images"
 		suffix = "images"
 	}
 	}
-	u := fmt.Sprintf("%srepositories/%s/%s", r.indexEndpoint.VersionString(1), remote.RemoteName(), suffix)
+	u := fmt.Sprintf("%srepositories/%s/%s", r.indexEndpoint.String(), remote.RemoteName(), suffix)
 	logrus.Debugf("[registry] PUT %s", u)
 	logrus.Debugf("[registry] PUT %s", u)
 	logrus.Debugf("Image list pushed to index:\n%s", imgListJSON)
 	logrus.Debugf("Image list pushed to index:\n%s", imgListJSON)
 	headers := map[string][]string{
 	headers := map[string][]string{
@@ -680,7 +680,7 @@ func (r *Session) PushImageJSONIndex(remote reference.Named, imgList []*ImgData,
 		if res.Header.Get("X-Docker-Endpoints") == "" {
 		if res.Header.Get("X-Docker-Endpoints") == "" {
 			return nil, fmt.Errorf("Index response didn't contain any endpoints")
 			return nil, fmt.Errorf("Index response didn't contain any endpoints")
 		}
 		}
-		endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1))
+		endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.String())
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -722,7 +722,7 @@ func shouldRedirect(response *http.Response) bool {
 // SearchRepositories performs a search against the remote repository
 // SearchRepositories performs a search against the remote repository
 func (r *Session) SearchRepositories(term string) (*registrytypes.SearchResults, error) {
 func (r *Session) SearchRepositories(term string) (*registrytypes.SearchResults, error) {
 	logrus.Debugf("Index server: %s", r.indexEndpoint)
 	logrus.Debugf("Index server: %s", r.indexEndpoint)
-	u := r.indexEndpoint.VersionString(1) + "search?q=" + url.QueryEscape(term)
+	u := r.indexEndpoint.String() + "search?q=" + url.QueryEscape(term)
 
 
 	req, err := http.NewRequest("GET", u, nil)
 	req, err := http.NewRequest("GET", u, nil)
 	if err != nil {
 	if err != nil {

+ 0 - 81
registry/token.go

@@ -1,81 +0,0 @@
-package registry
-
-import (
-	"encoding/json"
-	"errors"
-	"fmt"
-	"net/http"
-	"net/url"
-	"strings"
-)
-
-type tokenResponse struct {
-	Token string `json:"token"`
-}
-
-func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (string, error) {
-	realm, ok := params["realm"]
-	if !ok {
-		return "", errors.New("no realm specified for token auth challenge")
-	}
-
-	realmURL, err := url.Parse(realm)
-	if err != nil {
-		return "", fmt.Errorf("invalid token auth challenge realm: %s", err)
-	}
-
-	if realmURL.Scheme == "" {
-		if registryEndpoint.IsSecure {
-			realmURL.Scheme = "https"
-		} else {
-			realmURL.Scheme = "http"
-		}
-	}
-
-	req, err := http.NewRequest("GET", realmURL.String(), nil)
-	if err != nil {
-		return "", err
-	}
-
-	reqParams := req.URL.Query()
-	service := params["service"]
-	scope := params["scope"]
-
-	if service != "" {
-		reqParams.Add("service", service)
-	}
-
-	for _, scopeField := range strings.Fields(scope) {
-		reqParams.Add("scope", scopeField)
-	}
-
-	if username != "" {
-		reqParams.Add("account", username)
-		req.SetBasicAuth(username, password)
-	}
-
-	req.URL.RawQuery = reqParams.Encode()
-
-	resp, err := registryEndpoint.client.Do(req)
-	if err != nil {
-		return "", err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return "", fmt.Errorf("token auth attempt for registry %s: %s request failed with status: %d %s", registryEndpoint, req.URL, resp.StatusCode, http.StatusText(resp.StatusCode))
-	}
-
-	decoder := json.NewDecoder(resp.Body)
-
-	tr := new(tokenResponse)
-	if err = decoder.Decode(tr); err != nil {
-		return "", fmt.Errorf("unable to decode token response: %s", err)
-	}
-
-	if tr.Token == "" {
-		return "", errors.New("authorization server did not include a token in the response")
-	}
-
-	return tr.Token, nil
-}

+ 7 - 7
registry/types.go

@@ -46,18 +46,18 @@ func (av APIVersion) String() string {
 	return apiVersions[av]
 	return apiVersions[av]
 }
 }
 
 
-var apiVersions = map[APIVersion]string{
-	1: "v1",
-	2: "v2",
-}
-
 // API Version identifiers.
 // API Version identifiers.
 const (
 const (
-	APIVersionUnknown = iota
-	APIVersion1
+	_                      = iota
+	APIVersion1 APIVersion = iota
 	APIVersion2
 	APIVersion2
 )
 )
 
 
+var apiVersions = map[APIVersion]string{
+	APIVersion1: "v1",
+	APIVersion2: "v2",
+}
+
 // RepositoryInfo describes a repository
 // RepositoryInfo describes a repository
 type RepositoryInfo struct {
 type RepositoryInfo struct {
 	reference.Named
 	reference.Named