refact pkg/apiclient (#2846)

* extract resperr.go
* extract method prepareRequest()
* reset token inside mutex
This commit is contained in:
mmetc 2024-02-22 11:42:33 +01:00 committed by GitHub
parent 3e3df5e4c6
commit 8da490f593
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 49 deletions

View file

@ -130,20 +130,24 @@ func (t *JWTTransport) refreshJwtToken() error {
return nil
}
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
// we use a mutex to avoid this
// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
t.refreshTokenMutex.Lock()
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
if err := t.refreshJwtToken(); err != nil {
t.refreshTokenMutex.Unlock()
func (t *JWTTransport) needsTokenRefresh() bool {
return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())
}
// prepareRequest returns a copy of the request with the necessary authentication headers.
func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) {
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless
// and will cause overload on CAPI. We use a mutex to avoid this.
t.refreshTokenMutex.Lock()
defer t.refreshTokenMutex.Unlock()
// We bypass the refresh if we are requesting the login endpoint, as it does not require a token,
// and it leads to do 2 requests instead of one (refresh + actual login request).
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() {
if err := t.refreshJwtToken(); err != nil {
return nil, err
}
}
t.refreshTokenMutex.Unlock()
if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
@ -151,6 +155,16 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
return req, nil
}
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req, err := t.prepareRequest(req)
if err != nil {
return nil, err
}
if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil {
// we had an error (network error for example, or 401 because token is refused), reset the token?
t.Token = ""
t.ResetToken()
return resp, fmt.Errorf("performing jwt auth: %w", err)
}
@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() {
t.refreshTokenMutex.Unlock()
}
// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded.
// transport() returns a round tripper that retries once when the status is unauthorized,
// and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
transport := t.Transport
if transport == nil {

View file

@ -4,9 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
@ -167,44 +165,10 @@ type Response struct {
//...
}
type ErrorResponse struct {
models.ErrorResponse
}
func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}
return err
}
func newResponse(r *http.Response) *Response {
return &Response{Response: r}
}
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}
errorResponse := &ErrorResponse{}
data, err := io.ReadAll(r.Body)
if err == nil && len(data)>0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}
} else {
errorResponse.Message = new(string)
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
}
return errorResponse
}
type ListOpts struct {
//Page int
//PerPage int

46
pkg/apiclient/resperr.go Normal file
View file

@ -0,0 +1,46 @@
package apiclient
import (
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/crowdsec/pkg/models"
)
type ErrorResponse struct {
models.ErrorResponse
}
func (e *ErrorResponse) Error() string {
err := fmt.Sprintf("API error: %s", *e.Message)
if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors)
}
return err
}
// CheckResponse verifies the API response and builds an appropriate Go error if necessary.
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}
ret := &ErrorResponse{}
data, err := io.ReadAll(r.Body)
if err != nil || len(data) == 0 {
ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode))
return ret
}
if err := json.Unmarshal(data, ret); err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
}
return ret
}

View file

@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
scenario = *decision.Scenario
scope = types.ListOrigin
default:
// XXX: this or nil?
scenario = ""
scope = ""