Browse Source

apiclient: handle 0-byte error response (#2716)

* apiclient: correctly handle 0-byte response
* lint
mmetc 1 year ago
parent
commit
437a97510a

+ 13 - 7
pkg/apiclient/auth.go

@@ -3,6 +3,7 @@ package apiclient
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"math/rand"
@@ -13,7 +14,6 @@ import (
 	"time"
 
 	"github.com/go-openapi/strfmt"
-	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
 
 	"github.com/crowdsecurity/crowdsec/pkg/fflag"
@@ -52,10 +52,12 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 		dump, _ := httputil.DumpRequest(req, true)
 		log.Tracef("auth-api request: %s", string(dump))
 	}
+
 	// Make the HTTP request.
 	resp, err := t.transport().RoundTrip(req)
 	if err != nil {
 		log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
+
 		return resp, err
 	}
 
@@ -115,10 +117,12 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
 	for i := 0; i < maxAttempts; i++ {
 		if i > 0 {
 			if r.withBackOff {
+				//nolint:gosec
 				backoff += 10 + rand.Intn(20)
 			}
 
 			log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
+
 			select {
 			case <-req.Context().Done():
 				return resp, req.Context().Err()
@@ -134,8 +138,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
 		resp, err = r.next.RoundTrip(clonedReq)
 
 		if err != nil {
-			left := maxAttempts - i - 1
-			if left > 0 {
+			if left := maxAttempts - i - 1; left > 0 {
 				log.Errorf("error while performing request: %s; %d retries left", err, left)
 			}
 
@@ -177,7 +180,7 @@ func (t *JWTTransport) refreshJwtToken() error {
 		log.Debugf("scenarios list updated for '%s'", *t.MachineID)
 	}
 
-	var auth = models.WatcherAuthRequest{
+	auth := models.WatcherAuthRequest{
 		MachineID: t.MachineID,
 		Password:  t.Password,
 		Scenarios: t.Scenarios,
@@ -264,13 +267,14 @@ func (t *JWTTransport) refreshJwtToken() error {
 
 // RoundTrip implements the RoundTripper interface.
 func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	// in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
+	// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
 	// we use a mutex to avoid this
-	//We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
+	// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
 	t.refreshTokenMutex.Lock()
 	if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
 		if err := t.refreshJwtToken(); err != nil {
 			t.refreshTokenMutex.Unlock()
+
 			return nil, err
 		}
 	}
@@ -296,8 +300,9 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 
 	if err != nil {
-		/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
+		// we had an error (network error for example, or 401 because token is refused), reset the token ?
 		t.Token = ""
+
 		return resp, fmt.Errorf("performing jwt auth: %w", err)
 	}
 
@@ -355,6 +360,7 @@ func cloneRequest(r *http.Request) *http.Request {
 	*r2 = *r
 	// deep copy of the Header
 	r2.Header = make(http.Header, len(r.Header))
+
 	for k, s := range r.Header {
 		r2.Header[k] = append([]string(nil), s...)
 	}

+ 3 - 3
pkg/apiclient/client.go

@@ -74,6 +74,7 @@ func NewClient(config *Config) (*ApiClient, error) {
 		VersionPrefix:  config.VersionPrefix,
 		UpdateScenario: config.UpdateScenario,
 	}
+
 	tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
 	tlsconfig.RootCAs = CaCertPool
 
@@ -180,8 +181,7 @@ func (e *ErrorResponse) Error() string {
 }
 
 func newResponse(r *http.Response) *Response {
-	response := &Response{Response: r}
-	return response
+	return &Response{Response: r}
 }
 
 func CheckResponse(r *http.Response) error {
@@ -192,7 +192,7 @@ func CheckResponse(r *http.Response) error {
 	errorResponse := &ErrorResponse{}
 
 	data, err := io.ReadAll(r.Body)
-	if err == nil && data != nil {
+	if err == nil && len(data)>0 {
 		err := json.Unmarshal(data, errorResponse)
 		if err != nil {
 			return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)

+ 3 - 1
pkg/apiclient/decisions_service.go

@@ -183,7 +183,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
 
 	req = req.WithContext(ctx)
 	log.Debugf("[URL] %s %s", req.Method, req.URL)
-	// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
+	// we don't use client_http Do method because we need the reader and is not provided.
+	// We would be forced to use Pipe and goroutine, etc
 	resp, err := client.Do(req)
 	if resp != nil && resp.Body != nil {
 		defer resp.Body.Close()
@@ -216,6 +217,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
 
 	if resp.StatusCode != http.StatusOK {
 		log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
+
 		return nil, false, nil
 	}
 

+ 2 - 0
pkg/apiclient/heartbeat.go

@@ -38,11 +38,13 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
 			select {
 			case <-hbTimer.C:
 				log.Debug("heartbeat: sending heartbeat")
+
 				ok, resp, err := h.Ping(ctx)
 				if err != nil {
 					log.Errorf("heartbeat error : %s", err)
 					continue
 				}
+
 				resp.Response.Body.Close()
 				if resp.Response.StatusCode != http.StatusOK {
 					log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode)

+ 4 - 0
pkg/apiserver/apiserver.go

@@ -307,6 +307,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
 				log.Errorf("capi push: %s", err)
 				return err
 			}
+
 			return nil
 		})
 
@@ -315,6 +316,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
 				log.Errorf("capi pull: %s", err)
 				return err
 			}
+
 			return nil
 		})
 
@@ -328,6 +330,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
 							log.Errorf("papi pull: %s", err)
 							return err
 						}
+
 						return nil
 					})
 
@@ -336,6 +339,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
 							log.Errorf("capi decisions sync: %s", err)
 							return err
 						}
+
 						return nil
 					})
 				} else {

+ 2 - 1
pkg/apiserver/controllers/controller.go

@@ -55,6 +55,7 @@ func serveHealth() http.HandlerFunc {
 		// no caching required
 		health.WithDisabledCache(),
 	)
+
 	return health.NewHandler(checker)
 }
 
@@ -76,6 +77,7 @@ func (c *Controller) NewV1() error {
 	if err != nil {
 		return err
 	}
+
 	c.Router.GET("/health", gin.WrapF(serveHealth()))
 	c.Router.Use(v1.PrometheusMiddleware())
 	c.Router.HandleMethodNotAllowed = true
@@ -104,7 +106,6 @@ func (c *Controller) NewV1() error {
 		jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions)
 		jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById)
 		jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat)
-
 	}
 
 	apiKeyAuth := groupV1.Group("")

+ 15 - 6
pkg/apiserver/controllers/v1/alerts.go

@@ -22,7 +22,6 @@ import (
 )
 
 func FormatOneAlert(alert *ent.Alert) *models.Alert {
-	var outputAlert models.Alert
 	startAt := alert.StartedAt.String()
 	StopAt := alert.StoppedAt.String()
 
@@ -31,7 +30,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
 		machineID = alert.Edges.Owner.MachineId
 	}
 
-	outputAlert = models.Alert{
+	outputAlert := models.Alert{
 		ID:              int64(alert.ID),
 		MachineID:       machineID,
 		CreatedAt:       alert.CreatedAt.Format(time.RFC3339),
@@ -58,23 +57,27 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
 			Longitude: alert.SourceLongitude,
 		},
 	}
+
 	for _, eventItem := range alert.Edges.Events {
 		var Metas models.Meta
 		timestamp := eventItem.Time.String()
 		if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil {
 			log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err)
 		}
+
 		outputAlert.Events = append(outputAlert.Events, &models.Event{
 			Timestamp: &timestamp,
 			Meta:      Metas,
 		})
 	}
+
 	for _, metaItem := range alert.Edges.Metas {
 		outputAlert.Meta = append(outputAlert.Meta, &models.MetaItems0{
 			Key:   metaItem.Key,
 			Value: metaItem.Value,
 		})
 	}
+
 	for _, decisionItem := range alert.Edges.Decisions {
 		duration := decisionItem.Until.Sub(time.Now().UTC()).String()
 		outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{
@@ -88,6 +91,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
 			ID:        int64(decisionItem.ID),
 		})
 	}
+
 	return &outputAlert
 }
 
@@ -97,6 +101,7 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest {
 	for _, alertItem := range result {
 		data = append(data, FormatOneAlert(alertItem))
 	}
+
 	return data
 }
 
@@ -107,6 +112,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin
 			select {
 			case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}:
 				log.Debugf("alert sent to Plugin channel")
+
 				break RETRY
 			default:
 				log.Warningf("Cannot send alert to Plugin channel (try: %d)", try)
@@ -133,7 +139,6 @@ func normalizeScope(scope string) string {
 
 // CreateAlert writes the alerts received in the body to the database
 func (c *Controller) CreateAlert(gctx *gin.Context) {
-
 	var input models.AddAlertsRequest
 
 	claims := jwt.ExtractClaims(gctx)
@@ -144,13 +149,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
 		gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
 	}
+
 	if err := input.Validate(strfmt.Default); err != nil {
 		c.HandleDBErrors(gctx, err)
 		return
 	}
+
 	stopFlush := false
+
 	for _, alert := range input {
-		//normalize scope for alert.Source and decisions
+		// normalize scope for alert.Source and decisions
 		if alert.Source.Scope != nil {
 			*alert.Source.Scope = normalizeScope(*alert.Source.Scope)
 		}
@@ -161,15 +169,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
 		}
 
 		alert.MachineID = machineID
-		//generate uuid here for alert
+		// generate uuid here for alert
 		alert.UUID = uuid.NewString()
 
-		//if coming from cscli, alert already has decisions
+		// if coming from cscli, alert already has decisions
 		if len(alert.Decisions) != 0 {
 			//alert already has a decision (cscli decisions add etc.), generate uuid here
 			for _, decision := range alert.Decisions {
 				decision.UUID = uuid.NewString()
 			}
+
 			for pIdx, profile := range c.Profiles {
 				_, matched, err := profile.EvaluateProfile(alert)
 				if err != nil {