apiclient: fix http roundtrip (clone body also) (#1758)
* apiclient: fix http roundtrip (clone body also)
This commit is contained in:
parent
fe23da6e0c
commit
579cecde04
3 changed files with 44 additions and 17 deletions
|
@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
|
|||
type JWTTransport struct {
|
||||
MachineID *string
|
||||
Password *strfmt.Password
|
||||
token string
|
||||
Token string
|
||||
Expiration time.Time
|
||||
Scenarios []string
|
||||
URL *url.URL
|
||||
|
@ -88,6 +88,7 @@ type JWTTransport struct {
|
|||
// It will default to http.DefaultTransport if nil.
|
||||
Transport http.RoundTripper
|
||||
UpdateScenario func() ([]string, error)
|
||||
NbRetry int
|
||||
}
|
||||
|
||||
func (t *JWTTransport) refreshJwtToken() error {
|
||||
|
@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|||
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
|
||||
return errors.Wrap(err, "unable to parse jwt expiration")
|
||||
}
|
||||
t.token = response.Token
|
||||
t.Token = response.Token
|
||||
|
||||
log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
|
||||
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
|
||||
if t.NbRetry > 1 {
|
||||
t.NbRetry = 0
|
||||
return nil, fmt.Errorf("unable to refresh JWT token multiple times")
|
||||
}
|
||||
if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
|
||||
if err := t.refreshJwtToken(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// We must make a copy of the Request so
|
||||
// that we don't modify the Request we were given. This is required by the
|
||||
// specification of http.RoundTripper.
|
||||
req = cloneRequest(req)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
|
||||
log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
dump, _ := httputil.DumpRequest(req, true)
|
||||
log.Tracef("req-jwt: %s", string(dump))
|
||||
}
|
||||
if t.UserAgent != "" {
|
||||
req.Header.Add("User-Agent", t.UserAgent)
|
||||
}
|
||||
|
||||
// We must make a copy of the Request so
|
||||
// that we don't modify the Request we were given. This is required by the
|
||||
// specification of http.RoundTripper.
|
||||
clonedReq := cloneRequest(req)
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
|
||||
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
//requestToDump := cloneRequest(req)
|
||||
dump, _ := httputil.DumpRequest(req, true)
|
||||
log.Tracef("req-jwt: %s", string(dump))
|
||||
}
|
||||
|
||||
// Make the HTTP request.
|
||||
resp, err := t.transport().RoundTrip(req)
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
dump, _ := httputil.DumpResponse(resp, true)
|
||||
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
||||
}
|
||||
if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized {
|
||||
if err != nil {
|
||||
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
|
||||
t.token = ""
|
||||
t.Token = ""
|
||||
return resp, errors.Wrapf(err, "performing jwt auth")
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
t.Token = ""
|
||||
t.NbRetry++
|
||||
return t.RoundTrip(clonedReq)
|
||||
}
|
||||
|
||||
t.NbRetry = 0
|
||||
|
||||
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request {
|
|||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
|
||||
if r.Body != nil {
|
||||
var b bytes.Buffer
|
||||
b.ReadFrom(r.Body)
|
||||
r.Body = io.NopCloser(&b)
|
||||
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
|
||||
}
|
||||
return r2
|
||||
}
|
||||
|
|
|
@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
|
|||
}
|
||||
|
||||
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
|
||||
assert.Contains(t, err.Error(), "the attachment key provided is not valid")
|
||||
assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
|
||||
}
|
||||
|
|
|
@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) {
|
|||
UserAgent: config.UserAgent,
|
||||
VersionPrefix: config.VersionPrefix,
|
||||
UpdateScenario: config.UpdateScenario,
|
||||
NbRetry: 0,
|
||||
}
|
||||
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
||||
if Cert != nil {
|
||||
|
|
Loading…
Reference in a new issue