refact pkg/apiclient (#2846)
* extract resperr.go * extract method prepareRequest() * reset token inside mutex
This commit is contained in:
parent
3e3df5e4c6
commit
8da490f593
4 changed files with 73 additions and 49 deletions
|
@ -130,20 +130,24 @@ func (t *JWTTransport) refreshJwtToken() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
|
||||
// we use a mutex to avoid this
|
||||
// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
|
||||
t.refreshTokenMutex.Lock()
|
||||
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
|
||||
if err := t.refreshJwtToken(); err != nil {
|
||||
t.refreshTokenMutex.Unlock()
|
||||
func (t *JWTTransport) needsTokenRefresh() bool {
|
||||
return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())
|
||||
}
|
||||
|
||||
// prepareRequest returns a copy of the request with the necessary authentication headers.
|
||||
func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) {
|
||||
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless
|
||||
// and will cause overload on CAPI. We use a mutex to avoid this.
|
||||
t.refreshTokenMutex.Lock()
|
||||
defer t.refreshTokenMutex.Unlock()
|
||||
|
||||
// We bypass the refresh if we are requesting the login endpoint, as it does not require a token,
|
||||
// and it leads to do 2 requests instead of one (refresh + actual login request).
|
||||
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() {
|
||||
if err := t.refreshJwtToken(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
t.refreshTokenMutex.Unlock()
|
||||
|
||||
if t.UserAgent != "" {
|
||||
req.Header.Add("User-Agent", t.UserAgent)
|
||||
|
@ -151,6 +155,16 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req, err := t.prepareRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if log.GetLevel() >= log.TraceLevel {
|
||||
//requestToDump := cloneRequest(req)
|
||||
dump, _ := httputil.DumpRequest(req, true)
|
||||
|
@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
|
||||
if err != nil {
|
||||
// we had an error (network error for example, or 401 because token is refused), reset the token?
|
||||
t.Token = ""
|
||||
t.ResetToken()
|
||||
|
||||
return resp, fmt.Errorf("performing jwt auth: %w", err)
|
||||
}
|
||||
|
@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() {
|
|||
t.refreshTokenMutex.Unlock()
|
||||
}
|
||||
|
||||
// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded.
|
||||
// transport() returns a round tripper that retries once when the status is unauthorized,
|
||||
// and 5 times when the infrastructure is overloaded.
|
||||
func (t *JWTTransport) transport() http.RoundTripper {
|
||||
transport := t.Transport
|
||||
if transport == nil {
|
||||
|
|
|
@ -4,9 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
|
@ -167,44 +165,10 @@ type Response struct {
|
|||
//...
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
models.ErrorResponse
|
||||
}
|
||||
|
||||
func (e *ErrorResponse) Error() string {
|
||||
err := fmt.Sprintf("API error: %s", *e.Message)
|
||||
if len(e.Errors) > 0 {
|
||||
err += fmt.Sprintf(" (%s)", e.Errors)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newResponse(r *http.Response) *Response {
|
||||
return &Response{Response: r}
|
||||
}
|
||||
|
||||
func CheckResponse(r *http.Response) error {
|
||||
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
|
||||
return nil
|
||||
}
|
||||
|
||||
errorResponse := &ErrorResponse{}
|
||||
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err == nil && len(data)>0 {
|
||||
err := json.Unmarshal(data, errorResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
|
||||
}
|
||||
} else {
|
||||
errorResponse.Message = new(string)
|
||||
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
|
||||
}
|
||||
|
||||
return errorResponse
|
||||
}
|
||||
|
||||
type ListOpts struct {
|
||||
//Page int
|
||||
//PerPage int
|
||||
|
|
46
pkg/apiclient/resperr.go
Normal file
46
pkg/apiclient/resperr.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package apiclient
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
||||
|
||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
models.ErrorResponse
|
||||
}
|
||||
|
||||
func (e *ErrorResponse) Error() string {
|
||||
err := fmt.Sprintf("API error: %s", *e.Message)
|
||||
if len(e.Errors) > 0 {
|
||||
err += fmt.Sprintf(" (%s)", e.Errors)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckResponse verifies the API response and builds an appropriate Go error if necessary.
|
||||
func CheckResponse(r *http.Response) error {
|
||||
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := &ErrorResponse{}
|
||||
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil || len(data) == 0 {
|
||||
ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode))
|
||||
return ret
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, ret); err != nil {
|
||||
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
|
@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
|
|||
scenario = *decision.Scenario
|
||||
scope = types.ListOrigin
|
||||
default:
|
||||
// XXX: this or nil?
|
||||
scenario = ""
|
||||
scope = ""
|
||||
|
||||
|
|
Loading…
Reference in a new issue