Browse Source

apiclient: fix http roundtrip (clone body also) (#1758)

* apiclient: fix http roundtrip (clone body also)
he2ss 2 years ago
parent
commit
579cecde04
3 changed files with 40 additions and 13 deletions
  1. 38 12
      pkg/apiclient/auth.go
  2. 1 1
      pkg/apiclient/auth_service_test.go
  3. 1 0
      pkg/apiclient/client.go

+ 38 - 12
pkg/apiclient/auth.go

@@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
 type JWTTransport struct {
 	MachineID     *string
 	Password      *strfmt.Password
-	token         string
+	Token         string
 	Expiration    time.Time
 	Scenarios     []string
 	URL           *url.URL
@@ -88,6 +88,7 @@ type JWTTransport struct {
 	// It will default to http.DefaultTransport if nil.
 	Transport      http.RoundTripper
 	UpdateScenario func() ([]string, error)
+	NbRetry        int
 }
 
 func (t *JWTTransport) refreshJwtToken() error {
@@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error {
 	if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
 		return errors.Wrap(err, "unable to parse jwt expiration")
 	}
-	t.token = response.Token
+	t.Token = response.Token
 
-	log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
+	log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
 	return nil
 }
 
 // RoundTrip implements the RoundTripper interface.
 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
+	if t.NbRetry > 1 {
+		t.NbRetry = 0
+		return nil, fmt.Errorf("unable to refresh JWT token multiple times")
+	}
+	if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
 		if err := t.refreshJwtToken(); err != nil {
 			return nil, err
 		}
 	}
 
+	if t.UserAgent != "" {
+		req.Header.Add("User-Agent", t.UserAgent)
+	}
+
 	// We must make a copy of the Request so
 	// that we don't modify the Request we were given. This is required by the
 	// specification of http.RoundTripper.
-	req = cloneRequest(req)
-	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
-	log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
+	clonedReq := cloneRequest(req)
+
+	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
+
 	if log.GetLevel() >= log.TraceLevel {
+		//requestToDump := cloneRequest(req)
 		dump, _ := httputil.DumpRequest(req, true)
 		log.Tracef("req-jwt: %s", string(dump))
 	}
-	if t.UserAgent != "" {
-		req.Header.Add("User-Agent", t.UserAgent)
-	}
+
 	// Make the HTTP request.
 	resp, err := t.transport().RoundTrip(req)
 	if log.GetLevel() >= log.TraceLevel {
 		dump, _ := httputil.DumpResponse(resp, true)
 		log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
 	}
-	if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized {
+	if err != nil {
 		/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
-		t.token = ""
+		t.Token = ""
 		return resp, errors.Wrapf(err, "performing jwt auth")
 	}
+
+	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+		t.Token = ""
+		t.NbRetry++
+		return t.RoundTrip(clonedReq)
+	}
+
+	t.NbRetry = 0
+
 	log.Debugf("resp-jwt: %d", resp.StatusCode)
+
 	return resp, nil
 }
 
@@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request {
 	for k, s := range r.Header {
 		r2.Header[k] = append([]string(nil), s...)
 	}
+
+	if r.Body != nil {
+		var b bytes.Buffer
+		b.ReadFrom(r.Body)
+		r.Body = io.NopCloser(&b)
+		r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
+	}
 	return r2
 }

+ 1 - 1
pkg/apiclient/auth_service_test.go

@@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
 	}
 
 	_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
-	assert.Contains(t, err.Error(), "the attachment key provided is not valid")
+	assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
 }

+ 1 - 0
pkg/apiclient/client.go

@@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) {
 		UserAgent:      config.UserAgent,
 		VersionPrefix:  config.VersionPrefix,
 		UpdateScenario: config.UpdateScenario,
+		NbRetry:        0,
 	}
 	tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
 	if Cert != nil {