Просмотр исходного кода

Get token on each request

Signed-off-by: Derek McGowan <derek@mcgstyle.net>
Derek McGowan 10 лет назад
Родитель
Сommit
d094eb6f7f
2 измененных файлов с 62 добавлено и 32 удалено
  1. 37 23
      registry/auth.go
  2. 25 9
      registry/session_v2.go

+ 37 - 23
registry/auth.go

@@ -38,56 +38,70 @@ type ConfigFile struct {
 }
 }
 
 
 type RequestAuthorization struct {
 type RequestAuthorization struct {
-	Token    string
-	Username string
-	Password string
+	authConfig       *AuthConfig
+	registryEndpoint *Endpoint
+	resource         string
+	scope            string
+	actions          []string
 }
 }
 
 
-func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) (*RequestAuthorization, error) {
-	var auth RequestAuthorization
+func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) *RequestAuthorization {
+	return &RequestAuthorization{
+		authConfig:       authConfig,
+		registryEndpoint: registryEndpoint,
+		resource:         resource,
+		scope:            scope,
+		actions:          actions,
+	}
+}
 
 
+func (auth *RequestAuthorization) getToken() (string, error) {
+	// TODO check if already has token and before expiration
 	client := &http.Client{
 	client := &http.Client{
 		Transport: &http.Transport{
 		Transport: &http.Transport{
 			DisableKeepAlives: true,
 			DisableKeepAlives: true,
-			Proxy:             http.ProxyFromEnvironment,
-		},
+			Proxy:             http.ProxyFromEnvironment},
 		CheckRedirect: AddRequiredHeadersToRedirectedRequests,
 		CheckRedirect: AddRequiredHeadersToRedirectedRequests,
 	}
 	}
 	factory := HTTPRequestFactory(nil)
 	factory := HTTPRequestFactory(nil)
 
 
-	for _, challenge := range registryEndpoint.AuthChallenges {
-		log.Debugf("Using %q auth challenge with params %s for %s", challenge.Scheme, challenge.Parameters, authConfig.Username)
-
+	for _, challenge := range auth.registryEndpoint.AuthChallenges {
 		switch strings.ToLower(challenge.Scheme) {
 		switch strings.ToLower(challenge.Scheme) {
 		case "basic":
 		case "basic":
-			auth.Username = authConfig.Username
-			auth.Password = authConfig.Password
+			// no token necessary
 		case "bearer":
 		case "bearer":
+			log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, auth.authConfig.Username)
 			params := map[string]string{}
 			params := map[string]string{}
 			for k, v := range challenge.Parameters {
 			for k, v := range challenge.Parameters {
 				params[k] = v
 				params[k] = v
 			}
 			}
-			params["scope"] = fmt.Sprintf("%s:%s:%s", resource, scope, strings.Join(actions, ","))
-			token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client, factory)
+			params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
+			token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client, factory)
 			if err != nil {
 			if err != nil {
-				return nil, err
+				return "", err
 			}
 			}
+			// TODO cache token and set expiration to one minute from now
 
 
-			auth.Token = token
+			return token, nil
 		default:
 		default:
 			log.Infof("Unsupported auth scheme: %q", challenge.Scheme)
 			log.Infof("Unsupported auth scheme: %q", challenge.Scheme)
 		}
 		}
 	}
 	}
-
-	return &auth, nil
+	// TODO no expiration, do not reattempt to get a token
+	return "", nil
 }
 }
 
 
-func (auth *RequestAuthorization) Authorize(req *http.Request) {
-	if auth.Token != "" {
-		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", auth.Token))
-	} else if auth.Username != "" && auth.Password != "" {
-		req.SetBasicAuth(auth.Username, auth.Password)
+func (auth *RequestAuthorization) Authorize(req *http.Request) error {
+	token, err := auth.getToken()
+	if err != nil {
+		return err
+	}
+	if token != "" {
+		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
+	} else if auth.authConfig.Username != "" && auth.authConfig.Password != "" {
+		req.SetBasicAuth(auth.authConfig.Username, auth.authConfig.Password)
 	}
 	}
+	return nil
 }
 }
 
 
 // create a base64 encoded auth string to store in config
 // create a base64 encoded auth string to store in config

+ 25 - 9
registry/session_v2.go

@@ -42,7 +42,7 @@ func (r *Session) GetV2Authorization(imageName string, readOnly bool) (auth *Req
 	r.indexEndpoint = registry
 	r.indexEndpoint = registry
 
 
 	log.Debugf("Getting authorization for %s %s", imageName, scopes)
 	log.Debugf("Getting authorization for %s %s", imageName, scopes)
-	return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes)
+	return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes), nil
 }
 }
 
 
 //
 //
@@ -65,7 +65,9 @@ func (r *Session) GetV2ImageManifest(imageName, tagName string, auth *RequestAut
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -103,7 +105,9 @@ func (r *Session) PostV2ImageMountBlob(imageName, sumType, sum string, auth *Req
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
@@ -132,7 +136,9 @@ func (r *Session) GetV2ImageBlob(imageName, sumType, sum string, blobWrtr io.Wri
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -161,7 +167,9 @@ func (r *Session) GetV2ImageBlobReader(imageName, sumType, sum string, auth *Req
 	if err != nil {
 	if err != nil {
 		return nil, 0, err
 		return nil, 0, err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return nil, 0, err
 		return nil, 0, err
@@ -196,7 +204,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R
 		return err
 		return err
 	}
 	}
 
 
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -212,7 +222,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R
 	queryParams := url.Values{}
 	queryParams := url.Values{}
 	queryParams.Add("digest", sumType+":"+sumStr)
 	queryParams.Add("digest", sumType+":"+sumStr)
 	req.URL.RawQuery = queryParams.Encode()
 	req.URL.RawQuery = queryParams.Encode()
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err = r.doRequest(req)
 	res, _, err = r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -242,7 +254,9 @@ func (r *Session) PutV2ImageManifest(imageName, tagName string, manifestRdr io.R
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -274,7 +288,9 @@ func (r *Session) GetV2RemoteTags(imageName string, auth *RequestAuthorization)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	auth.Authorize(req)
+	if err := auth.Authorize(req) {
+		return nil, err
+	}
 	res, _, err := r.doRequest(req)
 	res, _, err := r.doRequest(req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err