From d094eb6f7ffe6b608ecde54297e107e5caa0954d Mon Sep 17 00:00:00 2001 From: Derek McGowan Date: Fri, 19 Dec 2014 16:14:04 -0800 Subject: [PATCH] Get token on each request Signed-off-by: Derek McGowan --- registry/auth.go | 60 ++++++++++++++++++++++++++---------------- registry/session_v2.go | 34 +++++++++++++++++------- 2 files changed, 62 insertions(+), 32 deletions(-) diff --git a/registry/auth.go b/registry/auth.go index b138fb530d..1e1c7ddb82 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -38,56 +38,70 @@ type ConfigFile struct { } type RequestAuthorization struct { - Token string - Username string - Password string + authConfig *AuthConfig + registryEndpoint *Endpoint + resource string + scope string + actions []string } -func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) (*RequestAuthorization, error) { - var auth RequestAuthorization +func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) *RequestAuthorization { + return &RequestAuthorization{ + authConfig: authConfig, + registryEndpoint: registryEndpoint, + resource: resource, + scope: scope, + actions: actions, + } +} +func (auth *RequestAuthorization) getToken() (string, error) { + // TODO check if already has token and before expiration client := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - }, + Proxy: http.ProxyFromEnvironment}, CheckRedirect: AddRequiredHeadersToRedirectedRequests, } factory := HTTPRequestFactory(nil) - for _, challenge := range registryEndpoint.AuthChallenges { - log.Debugf("Using %q auth challenge with params %s for %s", challenge.Scheme, challenge.Parameters, authConfig.Username) - + for _, challenge := range auth.registryEndpoint.AuthChallenges { switch strings.ToLower(challenge.Scheme) { case "basic": - auth.Username = authConfig.Username - auth.Password = authConfig.Password + // no token necessary case "bearer": + log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, auth.authConfig.Username) params := map[string]string{} for k, v := range challenge.Parameters { params[k] = v } - params["scope"] = fmt.Sprintf("%s:%s:%s", resource, scope, strings.Join(actions, ",")) - token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client, factory) + params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ",")) + token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client, factory) if err != nil { - return nil, err + return "", err } + // TODO cache token and set expiration to one minute from now - auth.Token = token + return token, nil default: log.Infof("Unsupported auth scheme: %q", challenge.Scheme) } } - - return &auth, nil + // TODO no expiration, do not reattempt to get a token + return "", nil } -func (auth *RequestAuthorization) Authorize(req *http.Request) { - if auth.Token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", auth.Token)) - } else if auth.Username != "" && auth.Password != "" { - req.SetBasicAuth(auth.Username, auth.Password) +func (auth *RequestAuthorization) Authorize(req *http.Request) error { + token, err := auth.getToken() + if err != nil { + return err } + if token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } else if auth.authConfig.Username != "" && auth.authConfig.Password != "" { + req.SetBasicAuth(auth.authConfig.Username, auth.authConfig.Password) + } + return nil } // create a base64 encoded auth string to store in config diff --git a/registry/session_v2.go b/registry/session_v2.go index 2304a61344..491cd2c6e0 100644 --- a/registry/session_v2.go +++ b/registry/session_v2.go @@ -42,7 +42,7 @@ func (r *Session) GetV2Authorization(imageName string, readOnly bool) (auth *Req r.indexEndpoint = registry log.Debugf("Getting authorization for %s %s", imageName, scopes) - return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes) + return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes), nil } // @@ -65,7 +65,9 @@ func (r *Session) GetV2ImageManifest(imageName, tagName string, auth *RequestAut if err != nil { return nil, err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return nil, err @@ -103,7 +105,9 @@ func (r *Session) PostV2ImageMountBlob(imageName, sumType, sum string, auth *Req if err != nil { return false, err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return false, err @@ -132,7 +136,9 @@ func (r *Session) GetV2ImageBlob(imageName, sumType, sum string, blobWrtr io.Wri if err != nil { return err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return err @@ -161,7 +167,9 @@ func (r *Session) GetV2ImageBlobReader(imageName, sumType, sum string, auth *Req if err != nil { return nil, 0, err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return nil, 0, err @@ -196,7 +204,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R return err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return err @@ -212,7 +222,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R queryParams := url.Values{} queryParams.Add("digest", sumType+":"+sumStr) req.URL.RawQuery = queryParams.Encode() - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err = r.doRequest(req) if err != nil { return err @@ -242,7 +254,9 @@ func (r *Session) PutV2ImageManifest(imageName, tagName string, manifestRdr io.R if err != nil { return err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return err @@ -274,7 +288,9 @@ func (r *Session) GetV2RemoteTags(imageName string, auth *RequestAuthorization) if err != nil { return nil, err } - auth.Authorize(req) + if err := auth.Authorize(req) { + return nil, err + } res, _, err := r.doRequest(req) if err != nil { return nil, err