Browse Source

retry with backoff requests to CAPI (#1957)

* backoff on refresh token error

* fix tls communication with lapi and user/pw auth (#1956)

allow self-signed TLS encryption with user/pw auth

docker:
 - remove defaults for certificate file locations
 - new envvar INSECURE_SKIP_VERIFY
 - register agent before TLS settings (cscli machine add removes them
   from the credentials file)

* separate cscli cobra constructors:  lapi, machines, bouncers, postoverflows (#1945)

* use feature toggling to improve testability with http retry backoff

* Add parse unix to dateparse enricher (#1958)

Add unix timestamp parsing for the case where we do have a strTime but it wasn't parsed using the conventional golang time layouts

* func tests: redirect stderr to filter extra logs (#1961)

* backoff on refresh token error

* use feature toggling to improve testability with http retry backoff

* refactor feature backoff toggle for tests

Co-authored-by: mmetc <92726601+mmetc@users.noreply.github.com>
Co-authored-by: Laurence Jones <laurence.jones@live.co.uk>
Cristian Nitescu 2 years ago
parent
commit
7284c0a47a

+ 1 - 0
.github/workflows/bats-mysql.yml

@@ -10,6 +10,7 @@ on:
 env:
 env:
   PREFIX_TEST_NAMES_WITH_FILE: true
   PREFIX_TEST_NAMES_WITH_FILE: true
 
 
+
 jobs:
 jobs:
 
 
   build:
   build:

+ 1 - 0
.github/workflows/go-tests-windows.yml

@@ -16,6 +16,7 @@ on:
 
 
 env:
 env:
   RICHGO_FORCE_COLOR: 1
   RICHGO_FORCE_COLOR: 1
+  CROWDSEC_FEATURE_DISABLE_HTTP_RETRY_BACKOFF: true
 
 
 jobs:
 jobs:
 
 

+ 1 - 0
.github/workflows/go-tests.yml

@@ -30,6 +30,7 @@ env:
   # and to override our endpoint in aws sdk
   # and to override our endpoint in aws sdk
   AWS_ENDPOINT_FORCE: http://localhost:4566
   AWS_ENDPOINT_FORCE: http://localhost:4566
   KINESIS_INITIALIZE_STREAMS: "stream-1-shard:1,stream-2-shards:2"
   KINESIS_INITIALIZE_STREAMS: "stream-1-shard:1,stream-2-shards:2"
+  CROWDSEC_FEATURE_DISABLE_HTTP_RETRY_BACKOFF: true
 
 
 jobs:
 jobs:
 
 

+ 97 - 23
pkg/apiclient/auth.go

@@ -3,6 +3,8 @@ package apiclient
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
+	"math/rand"
+	"sync"
 	"time"
 	"time"
 
 
 	//"errors"
 	//"errors"
@@ -12,6 +14,7 @@ import (
 	"net/http/httputil"
 	"net/http/httputil"
 	"net/url"
 	"net/url"
 
 
+	"github.com/crowdsecurity/crowdsec/pkg/fflag"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/go-openapi/strfmt"
 	"github.com/go-openapi/strfmt"
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
@@ -75,6 +78,53 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
 	return http.DefaultTransport
 	return http.DefaultTransport
 }
 }
 
 
+// retryRoundTripper is an http.RoundTripper that wraps another round tripper
+// and re-issues a request when the response status is one of retryStatusCodes.
+type retryRoundTripper struct {
+	// next is the wrapped round tripper that actually performs each attempt.
+	next             http.RoundTripper
+	// maxAttempts is the total number of attempts, including the first one.
+	maxAttempts      int
+	// retryStatusCodes lists the HTTP status codes that trigger another attempt.
+	retryStatusCodes []int
+	// withBackOff enables a growing, randomized delay between attempts
+	// (unless the DisableHttpRetryBackoff feature flag is enabled).
+	withBackOff      bool
+	// onBeforeRequest, when non-nil, is called with the zero-based attempt
+	// index right before each attempt is sent.
+	onBeforeRequest  func(attempt int)
+}
+
+// ShouldRetry reports whether statusCode is one of the status codes this
+// round tripper has been configured to retry on.
+func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
+	for idx := range r.retryStatusCodes {
+		if r.retryStatusCodes[idx] == statusCode {
+			return true
+		}
+	}
+	return false
+}
+
+// RoundTrip implements http.RoundTripper. It sends req through r.next up to
+// r.maxAttempts times, retrying when the attempt fails with an error or when
+// the response status satisfies ShouldRetry.
+//
+// Between attempts it waits for the accumulated backoff (10-29 extra seconds
+// per retry when withBackOff is set and the DisableHttpRetryBackoff feature
+// flag is off, otherwise no delay) and aborts early if the request context is
+// cancelled. The response or error of the last attempt is returned when all
+// attempts are exhausted.
+func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	var resp *http.Response
+	var err error
+	backoff := 0
+	for i := 0; i < r.maxAttempts; i++ {
+		if i > 0 {
+			// The previous response is being discarded in favour of a new
+			// attempt: close its body so the underlying connection is released
+			// instead of leaking.
+			if resp != nil {
+				resp.Body.Close()
+				resp = nil
+			}
+			if r.withBackOff && !fflag.DisableHttpRetryBackoff.IsEnabled() {
+				backoff += 10 + rand.Intn(20)
+			}
+			log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
+			select {
+			case <-req.Context().Done():
+				// A RoundTripper must not return both a response and an error.
+				return nil, req.Context().Err()
+			case <-time.After(time.Duration(backoff) * time.Second):
+			}
+		}
+		if r.onBeforeRequest != nil {
+			r.onBeforeRequest(i)
+		}
+		clonedReq := cloneRequest(req)
+		// The body stream was consumed by the previous attempt; when the
+		// request knows how to recreate it, send a fresh copy on retries.
+		// NOTE(review): assumes cloneRequest shares req.Body with the clone - confirm.
+		if i > 0 && req.GetBody != nil {
+			body, bodyErr := req.GetBody()
+			if bodyErr != nil {
+				return nil, bodyErr
+			}
+			clonedReq.Body = body
+		}
+		resp, err = r.next.RoundTrip(clonedReq)
+		if err == nil && !r.ShouldRetry(resp.StatusCode) {
+			return resp, nil
+		}
+	}
+	return resp, err
+}
+
 type JWTTransport struct {
 type JWTTransport struct {
 	MachineID     *string
 	MachineID     *string
 	Password      *strfmt.Password
 	Password      *strfmt.Password
@@ -86,9 +136,9 @@ type JWTTransport struct {
 	UserAgent     string
 	UserAgent     string
 	// Transport is the underlying HTTP transport to use when making requests.
 	// Transport is the underlying HTTP transport to use when making requests.
 	// It will default to http.DefaultTransport if nil.
 	// It will default to http.DefaultTransport if nil.
-	Transport      http.RoundTripper
-	UpdateScenario func() ([]string, error)
-	NbRetry        int
+	Transport         http.RoundTripper
+	UpdateScenario    func() ([]string, error)
+	refreshTokenMutex sync.Mutex
 }
 }
 
 
 func (t *JWTTransport) refreshJwtToken() error {
 func (t *JWTTransport) refreshJwtToken() error {
@@ -124,7 +174,14 @@ func (t *JWTTransport) refreshJwtToken() error {
 		return errors.Wrap(err, "could not create request")
 		return errors.Wrap(err, "could not create request")
 	}
 	}
 	req.Header.Add("Content-Type", "application/json")
 	req.Header.Add("Content-Type", "application/json")
-	client := &http.Client{}
+	client := &http.Client{
+		Transport: &retryRoundTripper{
+			next:             http.DefaultTransport,
+			maxAttempts:      5,
+			withBackOff:      true,
+			retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
+		},
+	}
 	if t.UserAgent != "" {
 	if t.UserAgent != "" {
 		req.Header.Add("User-Agent", t.UserAgent)
 		req.Header.Add("User-Agent", t.UserAgent)
 	}
 	}
@@ -150,6 +207,7 @@ func (t *JWTTransport) refreshJwtToken() error {
 
 
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
 		log.Debugf("received response status %q when fetching %v", resp.Status, req.URL)
+
 		err = CheckResponse(resp)
 		err = CheckResponse(resp)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -170,25 +228,21 @@ func (t *JWTTransport) refreshJwtToken() error {
 
 
 // RoundTrip implements the RoundTripper interface.
 // RoundTrip implements the RoundTripper interface.
 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	if t.NbRetry > 1 {
-		t.NbRetry = 0
-		return nil, fmt.Errorf("unable to refresh JWT token multiple times")
-	}
+	// in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
+	// we use a mutex to avoid this
+	t.refreshTokenMutex.Lock()
 	if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
 	if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
 		if err := t.refreshJwtToken(); err != nil {
 		if err := t.refreshJwtToken(); err != nil {
+			t.refreshTokenMutex.Unlock()
 			return nil, err
 			return nil, err
 		}
 		}
 	}
 	}
+	t.refreshTokenMutex.Unlock()
 
 
 	if t.UserAgent != "" {
 	if t.UserAgent != "" {
 		req.Header.Add("User-Agent", t.UserAgent)
 		req.Header.Add("User-Agent", t.UserAgent)
 	}
 	}
 
 
-	// We must make a copy of the Request so
-	// that we don't modify the Request we were given. This is required by the
-	// specification of http.RoundTripper.
-	clonedReq := cloneRequest(req)
-
 	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
 	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
 
 
 	if log.GetLevel() >= log.TraceLevel {
 	if log.GetLevel() >= log.TraceLevel {
@@ -209,14 +263,6 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 		return resp, errors.Wrapf(err, "performing jwt auth")
 		return resp, errors.Wrapf(err, "performing jwt auth")
 	}
 	}
 
 
-	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
-		t.Token = ""
-		t.NbRetry++
-		return t.RoundTrip(clonedReq)
-	}
-
-	t.NbRetry = 0
-
 	log.Debugf("resp-jwt: %d", resp.StatusCode)
 	log.Debugf("resp-jwt: %d", resp.StatusCode)
 
 
 	return resp, nil
 	return resp, nil
@@ -226,11 +272,39 @@ func (t *JWTTransport) Client() *http.Client {
 	return &http.Client{Transport: t}
 	return &http.Client{Transport: t}
 }
 }
 
 
+// ResetToken clears the cached JWT while holding refreshTokenMutex, so that
+// the next RoundTrip finds an empty token and refreshes it.
+func (t *JWTTransport) ResetToken() {
+	log.Debug("resetting jwt token")
+
+	t.refreshTokenMutex.Lock()
+	defer t.refreshTokenMutex.Unlock()
+
+	t.Token = ""
+}
+
 func (t *JWTTransport) transport() http.RoundTripper {
 func (t *JWTTransport) transport() http.RoundTripper {
+	var transport http.RoundTripper
 	if t.Transport != nil {
 	if t.Transport != nil {
-		return t.Transport
+		transport = t.Transport
+	} else {
+		transport = http.DefaultTransport
+	}
+	// a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded
+	return &retryRoundTripper{
+		next: &retryRoundTripper{
+			next:             transport,
+			maxAttempts:      5,
+			withBackOff:      true,
+			retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout},
+		},
+		maxAttempts:      2,
+		withBackOff:      false,
+		retryStatusCodes: []int{http.StatusUnauthorized, http.StatusForbidden},
+		onBeforeRequest: func(attempt int) {
+			// reset the token only in the second attempt as this is when we know we had a 401 or 403
+			// the second attempt is supposed to refresh the token
+			if attempt > 0 {
+				t.ResetToken()
+			}
+		},
 	}
 	}
-	return http.DefaultTransport
 }
 }
 
 
 // cloneRequest returns a clone of the provided *http.Request. The clone is a
 // cloneRequest returns a clone of the provided *http.Request. The clone is a

+ 1 - 1
pkg/apiclient/auth_service_test.go

@@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
 	}
 	}
 
 
 	_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
 	_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
-	assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
+	assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error())
 }
 }

+ 0 - 1
pkg/apiclient/client.go

@@ -51,7 +51,6 @@ func NewClient(config *Config) (*ApiClient, error) {
 		UserAgent:      config.UserAgent,
 		UserAgent:      config.UserAgent,
 		VersionPrefix:  config.VersionPrefix,
 		VersionPrefix:  config.VersionPrefix,
 		UpdateScenario: config.UpdateScenario,
 		UpdateScenario: config.UpdateScenario,
-		NbRetry:        0,
 	}
 	}
 	tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
 	tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
 	tlsconfig.RootCAs = CaCertPool
 	tlsconfig.RootCAs = CaCertPool

+ 5 - 0
pkg/fflag/crowdsec.go

@@ -3,12 +3,17 @@ package fflag
 var Crowdsec = FeatureRegister{EnvPrefix: "CROWDSEC_FEATURE_"}
 var Crowdsec = FeatureRegister{EnvPrefix: "CROWDSEC_FEATURE_"}
 
 
 var CscliSetup = &Feature{Name: "cscli_setup"}
 var CscliSetup = &Feature{Name: "cscli_setup"}
+var DisableHttpRetryBackoff = &Feature{Name: "disable_http_retry_backoff", Description: "Disable http retry backoff"}
 
 
 func RegisterAllFeatures() error {
 func RegisterAllFeatures() error {
 	err := Crowdsec.RegisterFeature(CscliSetup)
 	err := Crowdsec.RegisterFeature(CscliSetup)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	err = Crowdsec.RegisterFeature(DisableHttpRetryBackoff)
+	if err != nil {
+		return err
+	}
 
 
 	return nil
 	return nil
 }
 }

+ 1 - 0
tests/lib/setup.sh

@@ -9,3 +9,4 @@ load "../lib/bats-assert/load.bash"
 # mark the start of each test in the logs, beware crowdsec might be running
 # mark the start of each test in the logs, beware crowdsec might be running
 # echo "time=\"$(date +"%d-%m-%Y %H:%M:%S")\" level=info msg=\"TEST: ${BATS_TEST_DESCRIPTION}\"" >> /var/log/crowdsec.log
 # echo "time=\"$(date +"%d-%m-%Y %H:%M:%S")\" level=info msg=\"TEST: ${BATS_TEST_DESCRIPTION}\"" >> /var/log/crowdsec.log
 
 
+export CROWDSEC_FEATURE_DISABLE_HTTP_RETRY_BACKOFF=true