Explorar el Código

lint (errorlint) (#2644)

mmetc hace 1 año
padre
commit
08694adf1b

+ 2 - 1
cmd/crowdsec-cli/explain.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"io"
 	"os"
@@ -196,7 +197,7 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
 			errCount := 0
 			for {
 				input, err := reader.ReadBytes('\n')
-				if err != nil && err == io.EOF {
+				if err != nil && errors.Is(err, io.EOF) {
 					break
 				}
 				if len(input) > 1 {

+ 2 - 1
pkg/cache/cache.go

@@ -1,6 +1,7 @@
 package cache
 
 import (
+	"errors"
 	"time"
 
 	"github.com/bluele/gcache"
@@ -104,7 +105,7 @@ func GetKey(cacheName string, key string) (string, error) {
 		if name == cacheName {
 			if value, err := Caches[i].Get(key); err != nil {
 				//do not warn or log if key not found
-				if err == gcache.KeyNotFoundError {
+				if errors.Is(err, gcache.KeyNotFoundError) {
 					return "", nil
 				}
 				CacheConfig[i].Logger.Warningf("While getting key %s in cache %s: %s", key, cacheName, err)

+ 1 - 1
pkg/cticlient/client.go

@@ -71,7 +71,7 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map
 func (c *CrowdsecCTIClient) GetIPInfo(ip string) (*SmokeItem, error) {
 	body, err := c.doRequest(http.MethodGet, smokeEndpoint+"/"+ip, nil)
 	if err != nil {
-		if err == ErrNotFound {
+		if errors.Is(err, ErrNotFound) {
 			return &SmokeItem{}, nil
 		}
 		return nil, err

+ 1 - 1
pkg/database/alerts.go

@@ -742,7 +742,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
 	if machineID != "" {
 		owner, err = c.QueryMachineByID(machineID)
 		if err != nil {
-			if errors.Cause(err) != UserNotExists {
+			if !errors.Is(err, UserNotExists) {
 				return nil, fmt.Errorf("machine '%s': %w", machineID, err)
 			}
 

+ 3 - 3
pkg/exprhelpers/crowdsec_cti.go

@@ -104,12 +104,12 @@ func CrowdsecCTI(params ...any) (any, error) {
 	ctiResp, err := ctiClient.GetIPInfo(ip)
 	ctiClient.Logger.Debugf("request for %s took %v", ip, time.Since(before))
 	if err != nil {
-		switch err {
-		case cticlient.ErrUnauthorized:
+		switch {
+		case errors.Is(err, cticlient.ErrUnauthorized):
 			CTIApiEnabled = false
 			ctiClient.Logger.Errorf("Invalid API key provided, disabling CTI API")
 			return &cticlient.SmokeItem{}, cticlient.ErrUnauthorized
-		case cticlient.ErrLimit:
+		case errors.Is(err, cticlient.ErrLimit):
 			CTIBackOffUntil = time.Now().Add(CTIBackOffDuration)
 			ctiClient.Logger.Errorf("CTI API is throttled, will try again in %s", CTIBackOffDuration)
 			return &cticlient.SmokeItem{}, cticlient.ErrLimit

+ 2 - 1
pkg/exprhelpers/crowdsec_cti_test.go

@@ -3,6 +3,7 @@ package exprhelpers
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"io"
 	"net/http"
 	"strings"
@@ -108,7 +109,7 @@ func smokeHandler(req *http.Request) *http.Response {
 
 func TestNillClient(t *testing.T) {
 	defer ShutdownCrowdsecCTI()
-	if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); err != cticlient.ErrDisabled {
+	if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) {
 		t.Fatalf("failed to init CTI : %s", err)
 	}
 	item, err := CrowdsecCTI("1.2.3.4")

+ 4 - 3
pkg/longpollclient/client.go

@@ -2,6 +2,7 @@ package longpollclient
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -112,7 +113,7 @@ func (c *LongPollClient) poll() error {
 			var pollResp pollResponse
 			err = decoder.Decode(&pollResp)
 			if err != nil {
-				if err == io.EOF {
+				if errors.Is(err, io.EOF) {
 					logger.Debugf("server closed connection")
 					return nil
 				}
@@ -158,7 +159,7 @@ func (c *LongPollClient) pollEvents() error {
 			err := c.poll()
 			if err != nil {
 				c.logger.Errorf("failed to poll: %s", err)
-				if err == errUnauthorized {
+				if errors.Is(err, errUnauthorized) {
 					c.t.Kill(err)
 					close(c.c)
 					return err
@@ -198,7 +199,7 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) {
 		var pollResp pollResponse
 		err = decoder.Decode(&pollResp)
 		if err != nil {
-			if err == io.EOF {
+			if errors.Is(err, io.EOF) {
 				c.logger.Debugf("server closed connection")
 				break
 			}

+ 12 - 22
pkg/parser/enrich_date_test.go

@@ -4,34 +4,33 @@ import (
 	"testing"
 
 	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
 
-	"github.com/crowdsecurity/go-cs-lib/ptr"
+	"github.com/crowdsecurity/go-cs-lib/cstest"
 
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 )
 
 func TestDateParse(t *testing.T) {
 	tests := []struct {
-		name             string
-		evt              types.Event
-		expected_err     *error
-		expected_strTime *string
+		name        string
+		evt         types.Event
+		expectedErr string
+		expected    string
 	}{
 		{
 			name: "RFC3339",
 			evt: types.Event{
 				StrTime: "2019-10-12T07:20:50.52Z",
 			},
-			expected_err:     nil,
-			expected_strTime: ptr.Of("2019-10-12T07:20:50.52Z"),
+			expected: "2019-10-12T07:20:50.52Z",
 		},
 		{
 			name: "02/Jan/2006:15:04:05 -0700",
 			evt: types.Event{
 				StrTime: "02/Jan/2006:15:04:05 -0700",
 			},
-			expected_err:     nil,
-			expected_strTime: ptr.Of("2006-01-02T15:04:05-07:00"),
+			expected: "2006-01-02T15:04:05-07:00",
 		},
 		{
 			name: "Dec 17 08:17:43",
@@ -39,8 +38,7 @@ func TestDateParse(t *testing.T) {
 				StrTime:       "2011 X 17 zz 08X17X43 oneone Dec",
 				StrTimeFormat: "2006 X 2 zz 15X04X05 oneone Jan",
 			},
-			expected_err:     nil,
-			expected_strTime: ptr.Of("2011-12-17T08:17:43Z"),
+			expected: "2011-12-17T08:17:43Z",
 		},
 	}
 
@@ -51,19 +49,11 @@ func TestDateParse(t *testing.T) {
 		tt := tt
 		t.Run(tt.name, func(t *testing.T) {
 			strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, nil, logger)
-			if tt.expected_err != nil {
-				if err != *tt.expected_err {
-					t.Errorf("%s: expected error %v, got %v", tt.name, tt.expected_err, err)
-				}
-			} else if err != nil {
-				t.Errorf("%s: expected no error, got %v", tt.name, err)
-			}
-			if err != nil {
+			cstest.RequireErrorContains(t, err, tt.expectedErr)
+			if tt.expectedErr != "" {
 				return
 			}
-			if tt.expected_strTime != nil && strTime["MarshaledTime"] != *tt.expected_strTime {
-				t.Errorf("expected strTime %s, got %s", *tt.expected_strTime, strTime["MarshaledTime"])
-			}
+			assert.Equal(t, tt.expected, strTime["MarshaledTime"])
 		})
 	}
 }