|
@@ -3,6 +3,8 @@ package apiclient
|
|
|
import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
+ "math/rand"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
|
|
|
//"errors"
|
|
@@ -12,6 +14,7 @@ import (
|
|
|
"net/http/httputil"
|
|
|
"net/url"
|
|
|
|
|
|
+ "github.com/crowdsecurity/crowdsec/pkg/fflag"
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
|
|
"github.com/go-openapi/strfmt"
|
|
|
"github.com/pkg/errors"
|
|
@@ -75,6 +78,53 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
|
|
|
return http.DefaultTransport
|
|
|
}
|
|
|
|
|
|
+type retryRoundTripper struct {
|
|
|
+ next http.RoundTripper
|
|
|
+ maxAttempts int
|
|
|
+ retryStatusCodes []int
|
|
|
+ withBackOff bool
|
|
|
+ onBeforeRequest func(attempt int)
|
|
|
+}
|
|
|
+
|
|
|
+func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
|
|
|
+ for _, code := range r.retryStatusCodes {
|
|
|
+ if code == statusCode {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
+func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
+ var resp *http.Response
|
|
|
+ var err error
|
|
|
+ backoff := 0
|
|
|
+ for i := 0; i < r.maxAttempts; i++ {
|
|
|
+ if i > 0 {
|
|
|
+ if r.withBackOff && !fflag.DisableHttpRetryBackoff.IsEnabled() {
|
|
|
+ backoff += 10 + rand.Intn(20)
|
|
|
+ }
|
|
|
+ log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
|
|
|
+ select {
|
|
|
+ case <-req.Context().Done():
|
|
|
+ return resp, req.Context().Err()
|
|
|
+ case <-time.After(time.Duration(backoff) * time.Second):
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if r.onBeforeRequest != nil {
|
|
|
+ r.onBeforeRequest(i)
|
|
|
+ }
|
|
|
+ clonedReq := cloneRequest(req)
|
|
|
+ resp, err = r.next.RoundTrip(clonedReq)
|
|
|
+ if err == nil {
|
|
|
+ if !r.ShouldRetry(resp.StatusCode) {
|
|
|
+ return resp, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return resp, err
|
|
|
+}
|
|
|
+
|
|
|
type JWTTransport struct {
|
|
|
MachineID *string
|
|
|
Password *strfmt.Password
|
|
@@ -86,9 +136,9 @@ type JWTTransport struct {
|
|
|
UserAgent string
|
|
|
// Transport is the underlying HTTP transport to use when making requests.
|
|
|
// It will default to http.DefaultTransport if nil.
|
|
|
- Transport http.RoundTripper
|
|
|
- UpdateScenario func() ([]string, error)
|
|
|
- NbRetry int
|
|
|
+ Transport http.RoundTripper
|
|
|
+ UpdateScenario func() ([]string, error)
|
|
|
+ refreshTokenMutex sync.Mutex
|
|
|
}
|
|
|
|
|
|
func (t *JWTTransport) refreshJwtToken() error {
|
|
@@ -124,7 +174,14 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|
|
return errors.Wrap(err, "could not create request")
|
|
|
}
|
|
|
req.Header.Add("Content-Type", "application/json")
|
|
|
- client := &http.Client{}
|
|
|
+ client := &http.Client{
|
|
|
+ Transport: &retryRoundTripper{
|
|
|
+ next: http.DefaultTransport,
|
|
|
+ maxAttempts: 5,
|
|
|
+ withBackOff: true,
|
|
|
+ retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
|
|
|
+ },
|
|
|
+ }
|
|
|
if t.UserAgent != "" {
|
|
|
req.Header.Add("User-Agent", t.UserAgent)
|
|
|
}
|
|
@@ -150,6 +207,7 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
|
log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
|
|
|
+
|
|
|
err = CheckResponse(resp)
|
|
|
if err != nil {
|
|
|
return err
|
|
@@ -170,25 +228,21 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|
|
|
|
|
// RoundTrip implements the RoundTripper interface.
|
|
|
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
- if t.NbRetry > 1 {
|
|
|
- t.NbRetry = 0
|
|
|
- return nil, fmt.Errorf("unable to refresh JWT token multiple times")
|
|
|
- }
|
|
|
+ // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
|
|
|
+ // we use a mutex to avoid this
|
|
|
+ t.refreshTokenMutex.Lock()
|
|
|
if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
|
|
|
if err := t.refreshJwtToken(); err != nil {
|
|
|
+ t.refreshTokenMutex.Unlock()
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
+ t.refreshTokenMutex.Unlock()
|
|
|
|
|
|
if t.UserAgent != "" {
|
|
|
req.Header.Add("User-Agent", t.UserAgent)
|
|
|
}
|
|
|
|
|
|
- // We must make a copy of the Request so
|
|
|
- // that we don't modify the Request we were given. This is required by the
|
|
|
- // specification of http.RoundTripper.
|
|
|
- clonedReq := cloneRequest(req)
|
|
|
-
|
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
|
|
|
|
|
|
if log.GetLevel() >= log.TraceLevel {
|
|
@@ -209,14 +263,6 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
return resp, errors.Wrapf(err, "performing jwt auth")
|
|
|
}
|
|
|
|
|
|
- if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
|
|
- t.Token = ""
|
|
|
- t.NbRetry++
|
|
|
- return t.RoundTrip(clonedReq)
|
|
|
- }
|
|
|
-
|
|
|
- t.NbRetry = 0
|
|
|
-
|
|
|
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
|
|
|
|
|
return resp, nil
|
|
@@ -226,11 +272,39 @@ func (t *JWTTransport) Client() *http.Client {
|
|
|
return &http.Client{Transport: t}
|
|
|
}
|
|
|
|
|
|
+func (t *JWTTransport) ResetToken() {
|
|
|
+ log.Debug("resetting jwt token")
|
|
|
+ t.refreshTokenMutex.Lock()
|
|
|
+ t.Token = ""
|
|
|
+ t.refreshTokenMutex.Unlock()
|
|
|
+}
|
|
|
+
|
|
|
func (t *JWTTransport) transport() http.RoundTripper {
|
|
|
+ var transport http.RoundTripper
|
|
|
if t.Transport != nil {
|
|
|
- return t.Transport
|
|
|
+ transport = t.Transport
|
|
|
+ } else {
|
|
|
+ transport = http.DefaultTransport
|
|
|
+ }
|
|
|
+ // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded
|
|
|
+ return &retryRoundTripper{
|
|
|
+ next: &retryRoundTripper{
|
|
|
+ next: transport,
|
|
|
+ maxAttempts: 5,
|
|
|
+ withBackOff: true,
|
|
|
+ retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
|
|
|
+ },
|
|
|
+ maxAttempts: 2,
|
|
|
+ withBackOff: false,
|
|
|
+ retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
|
|
|
+ onBeforeRequest: func(attempt int) {
|
|
|
+ // reset the token only in the second attempt as this is when we know we had a 401 or 403
|
|
|
+ // the second attempt is supposed to refresh the token
|
|
|
+ if attempt > 0 {
|
|
|
+ t.ResetToken()
|
|
|
+ }
|
|
|
+ },
|
|
|
}
|
|
|
- return http.DefaultTransport
|
|
|
}
|
|
|
|
|
|
// cloneRequest returns a clone of the provided *http.Request. The clone is a
|