From 5622ac83382843e9cf60f29bd2d5303b2511aa63 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 5 Jan 2024 15:26:13 +0100 Subject: [PATCH] CI: enable testifylint (#2696) - reverse actual and expected values - use assert.False, assert.True - use assert.Len, assert.Emtpy - use require.Error, require.NoError - use assert.InDelta --- .golangci.yml | 40 ------- pkg/acquisition/acquisition_test.go | 40 ++++--- pkg/acquisition/modules/docker/docker_test.go | 34 ++++-- .../modules/journalctl/journalctl_test.go | 42 +++++-- .../modules/kubernetesaudit/k8s_audit_test.go | 18 +-- .../modules/wineventlog/wineventlog_test.go | 6 +- pkg/alertcontext/alertcontext_test.go | 6 +- pkg/csplugin/broker_test.go | 30 ++--- pkg/csplugin/broker_win_test.go | 6 +- pkg/cticlient/client_test.go | 64 +++++++---- pkg/cticlient/types_test.go | 29 ++--- pkg/exprhelpers/crowdsec_cti_test.go | 36 +++--- pkg/exprhelpers/exprlib_test.go | 107 ++++++++++-------- pkg/exprhelpers/jsonextract_test.go | 34 +++--- pkg/setup/detect_test.go | 4 +- pkg/types/event_test.go | 2 +- 16 files changed, 283 insertions(+), 215 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index eba554d87..5c0bab58c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -242,42 +242,6 @@ issues: # Will fix, trivial - just beware of merge conflicts - - linters: - - testifylint - text: "expected-actual: need to reverse actual and expected values" - - - linters: - - testifylint - text: "bool-compare: use assert.False" - - - linters: - - testifylint - text: "len: use assert.Len" - - - linters: - - testifylint - text: "bool-compare: use assert.True" - - - linters: - - testifylint - text: "bool-compare: use require.True" - - - linters: - - testifylint - text: "require-error: for error assertions use require" - - - linters: - - testifylint - text: "error-nil: use assert.NoError" - - - linters: - - testifylint - text: "error-nil: use assert.Error" - - - linters: - - testifylint - text: "empty: use assert.Empty" - - linters: - perfsprint text: "fmt.Sprintf can be replaced .*" @@ -286,10 +250,6 @@ issues: # Will fix, easy but some neurons required # - - linters: - - testifylint - text: "float-compare: use assert.InEpsilon .*or InDelta.*" - - linters: - errorlint text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors" diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index c1373a6c7..44b3878e1 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -40,15 +40,19 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error { if err := f.UnmarshalConfig(cfg); err != nil { return err } + if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + if f.Toto == "" { return fmt.Errorf("expect non-empty toto") } + return nil } func (f *MockSource) GetMode() string { return f.Mode } @@ -77,6 +81,7 @@ func appendMockSource() { if GetDataSourceIface("mock") == nil { AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } } + if GetDataSourceIface("mock_cant_run") == nil { AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } } @@ -84,6 +89,7 @@ func appendMockSource() { func TestDataSourceConfigure(t *testing.T) { appendMockSource() + tests := []struct { TestName string String string @@ -185,22 +191,22 @@ wowo: ajsajasjas switch tc.TestName { case "basic_valid_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.InfoLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.InfoLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_debug_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_tailmode_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "tail") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "tail", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) } }) } @@ -208,6 +214,7 @@ wowo: ajsajasjas func TestLoadAcquisitionFromFile(t *testing.T) { appendMockSource() + tests := []struct { TestName string Config csconfig.CrowdsecServiceCfg @@ -284,7 +291,6 @@ func TestLoadAcquisitionFromFile(t *testing.T) { assert.Len(t, dss, tc.ExpectedLen) }) - } } @@ -304,9 +310,11 @@ func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error { if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -319,6 +327,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro evt.Line.Src = "test" out <- evt } + return nil } func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { @@ -345,9 +354,11 @@ func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error { if f.Mode == "" { f.Mode = configuration.TAIL_MODE } + if f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -364,6 +375,7 @@ func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro out <- evt } <-t.Dying() + return nil } func (f *MockTail) CanRun() error { return nil } @@ -446,6 +458,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) out <- evt } t.Kill(fmt.Errorf("got error (tomb)")) + return fmt.Errorf("got error") } @@ -499,6 +512,7 @@ func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, l if dsn != "test_expect" { return fmt.Errorf("unexpected value") } + return nil } func (f *MockSourceByDSN) GetUuid() string { return "" } diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index 3c3eeefe6..c4d23168a 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -57,6 +57,7 @@ container_name: subLogger := log.WithFields(log.Fields{ "type": "docker", }) + for _, test := range tests { f := DockerSource{} err := f.Configure([]byte(test.config), subLogger) @@ -66,12 +67,15 @@ container_name: func TestConfigureDSN(t *testing.T) { log.Infof("Test 'TestConfigureDSN'") + var dockerHost string + if runtime.GOOS == "windows" { dockerHost = "npipe:////./pipe/docker_engine" } else { dockerHost = "unix:///var/run/podman/podman.sock" } + tests := []struct { name string dsn string @@ -106,6 +110,7 @@ func TestConfigureDSN(t *testing.T) { subLogger := log.WithFields(log.Fields{ "type": "docker", }) + for _, test := range tests { f := DockerSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -156,8 +161,11 @@ container_name_regexp: } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry + var ( + logger *log.Logger + subLogger *log.Entry + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) subLogger = logger.WithFields(log.Fields{ @@ -173,10 +181,12 @@ container_name_regexp: dockerTomb := tomb.Tomb{} out := make(chan types.Event) dockerSource := DockerSource{} + err := dockerSource.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + dockerSource.Client = new(mockDockerCli) actualLines := 0 readerTomb := &tomb.Tomb{} @@ -204,21 +214,23 @@ container_name_regexp: if err := readerTomb.Wait(); err != nil { t.Fatal(err) } + if ts.expectedLines != 0 { assert.Equal(t, ts.expectedLines, actualLines) } + err = streamTomb.Wait() if err != nil { t.Fatalf("docker acquisition error: %s", err) } } - } func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) { if readLogs == true { return []dockerTypes.Container{}, nil } + containers := make([]dockerTypes.Container, 0) container := &dockerTypes.Container{ ID: "12456", @@ -233,16 +245,20 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o if readLogs == true { return io.NopCloser(strings.NewReader("")), nil } + readLogs = true data := []string{"docker\n", "test\n", "1234\n"} ret := "" + for _, line := range data { startLineByte := make([]byte, 8) binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } + r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser + return r, nil } @@ -252,6 +268,7 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke Tty: false, }, } + return r, nil } @@ -285,8 +302,11 @@ func TestOneShot(t *testing.T) { } for _, ts := range tests { - var subLogger *log.Entry - var logger *log.Logger + var ( + subLogger *log.Entry + logger *log.Logger + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) subLogger = logger.WithFields(log.Fields{ @@ -307,6 +327,7 @@ func TestOneShot(t *testing.T) { if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil { t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err) } + dockerClient.Client = new(mockDockerCli) out := make(chan types.Event, 100) tomb := tomb.Tomb{} @@ -315,8 +336,7 @@ func TestOneShot(t *testing.T) { // else we do the check before actualLines is incremented ... if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } } - } diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 0ad49edd8..a91fba31b 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -21,6 +21,7 @@ func TestBadConfiguration(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -48,6 +49,7 @@ journalctl_filter: subLogger := log.WithFields(log.Fields{ "type": "journalctl", }) + for _, test := range tests { f := JournalCtlSource{} err := f.Configure([]byte(test.config), subLogger) @@ -59,6 +61,7 @@ func TestConfigureDSN(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { dsn string expectedErr string @@ -92,9 +95,11 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "", }, } + subLogger := log.WithFields(log.Fields{ "type": "journalctl", }) + for _, test := range tests { f := JournalCtlSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -106,6 +111,7 @@ func TestOneShot(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -137,9 +143,12 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) @@ -151,27 +160,32 @@ journalctl_filter: "type": "journalctl", }) } + tomb := tomb.Tomb{} out := make(chan types.Event, 100) j := JournalCtlSource{} + err := j.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + err = j.OneShotAcquisition(out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -182,6 +196,7 @@ func TestStreaming(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -202,9 +217,12 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) @@ -216,14 +234,18 @@ journalctl_filter: "type": "journalctl", }) } + tomb := tomb.Tomb{} out := make(chan types.Event) j := JournalCtlSource{} + err := j.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + actualLines := 0 + if ts.expectedLines != 0 { go func() { READLOOP: @@ -240,6 +262,7 @@ journalctl_filter: err = j.StreamingAcquisition(out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } @@ -248,16 +271,20 @@ journalctl_filter: time.Sleep(1 * time.Second) assert.Equal(t, ts.expectedLines, actualLines) } + tomb.Kill(nil) tomb.Wait() + output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput() if string(output) != "" { t.Fatalf("Found a journalctl process after killing the tomb !") } + if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -270,5 +297,6 @@ func TestMain(m *testing.M) { fullPath := filepath.Join(currentDir, "test_files") os.Setenv("PATH", fullPath+":"+os.Getenv("PATH")) } + os.Exit(m.Run()) } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 799868dc8..c3502c956 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -9,6 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" ) @@ -78,24 +79,23 @@ webhook_path: /k8s-audit`, err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) + require.NoError(t, err) err = f.Configure([]byte(test.config), subLogger) - assert.NoError(t, err) + require.NoError(t, err) f.StreamingAcquisition(out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) err = tb.Wait() if test.expectedErr != "" { - assert.ErrorContains(t, err, test.expectedErr) + require.ErrorContains(t, err, test.expectedErr) return } - assert.NoError(t, err) + require.NoError(t, err) }) } - } func TestHandler(t *testing.T) { @@ -252,10 +252,10 @@ webhook_path: /k8s-audit`, f := KubernetesAuditSource{} err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) + require.NoError(t, err) err = f.Configure([]byte(test.config), subLogger) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() @@ -268,11 +268,11 @@ webhook_path: /k8s-audit`, assert.Equal(t, test.expectedStatusCode, res.StatusCode) //time.Sleep(1 * time.Second) - assert.NoError(t, err) + require.NoError(t, err) tb.Kill(nil) err = tb.Wait() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.eventCount, eventCount) }) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_test.go index 20f8a5834..053ba88b5 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_test.go @@ -11,6 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc/eventlog" "gopkg.in/tomb.v2" ) @@ -124,7 +125,7 @@ event_level: bla`, } assert.Contains(t, err.Error(), test.expectedErr) } else { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectedQuery, q) } } @@ -221,9 +222,8 @@ event_ids: } } if test.expectedLines == nil { - assert.Equal(t, 0, len(linesRead)) + assert.Empty(t, linesRead) } else { - assert.Equal(t, len(test.expectedLines), len(linesRead)) assert.Equal(t, test.expectedLines, linesRead) } to.Kill(nil) diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index 2e7e71bd6..8b598eab8 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -7,6 +7,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAlertContext(t *testing.T) { @@ -29,8 +30,7 @@ func TestNewAlertContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test '%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, test.expectedErr) - + require.ErrorIs(t, err, test.expectedErr) } } @@ -193,7 +193,7 @@ func TestEventToContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test '%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, nil) + require.NoError(t, err) metas, _ := EventToContext(test.events) assert.ElementsMatch(t, test.expectedResult, metas) diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index f41eb8031..9adb35ad7 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -149,7 +149,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -182,7 +182,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() { err = json.Unmarshal(content, &alerts) log.Printf("content-> %s", content) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } @@ -199,7 +199,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -215,11 +215,11 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { time.Sleep(1 * time.Second) // after 1 seconds, we should have data content, err := os.ReadFile("./out") - assert.NoError(t, err) + require.NoError(t, err) var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 3) } @@ -235,7 +235,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -259,7 +259,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) } @@ -275,7 +275,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -306,11 +306,11 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { // two notifications, one with 4 alerts, one with 2 alerts err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 2) err = decoder.Decode(&alerts) @@ -328,7 +328,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -348,7 +348,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } @@ -358,7 +358,7 @@ func (s *PluginSuite) TestBrokerRunSimple() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -382,11 +382,11 @@ func (s *PluginSuite) TestBrokerRunSimple() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 6466e0d54..97a3ad33d 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -70,7 +70,7 @@ func (s *PluginSuite) TestBrokerRun() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -94,11 +94,11 @@ func (s *PluginSuite) TestBrokerRun() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/cticlient/client_test.go b/pkg/cticlient/client_test.go index a229bde55..79406a6c2 100644 --- a/pkg/cticlient/client_test.go +++ b/pkg/cticlient/client_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/ptr" ) @@ -36,25 +37,30 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { // wip func fireHandler(req *http.Request) *http.Response { var err error + apiKey := req.Header.Get("x-api-key") if apiKey != validApiKey { log.Warningf("invalid api key: %s", apiKey) + return &http.Response{ StatusCode: http.StatusForbidden, Body: nil, Header: make(http.Header), } } + //unmarshal data if fireResponses == nil { page1, err := os.ReadFile("tests/fire-page1.json") if err != nil { panic("can't read file") } + page2, err := os.ReadFile("tests/fire-page2.json") if err != nil { panic("can't read file") } + fireResponses = []string{string(page1), string(page2)} } //let's assume we have two valid pages. @@ -70,6 +76,7 @@ func fireHandler(req *http.Request) *http.Response { //how to react if you give a page number that is too big ? if page > len(fireResponses) { log.Warningf(" page too big %d vs %d", page, len(fireResponses)) + emptyResponse := `{ "_links": { "first": { @@ -82,8 +89,10 @@ func fireHandler(req *http.Request) *http.Response { "items": [] } ` + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))} } + reader := io.NopCloser(strings.NewReader(fireResponses[page-1])) //we should care about limit too return &http.Response{ @@ -106,6 +115,7 @@ func smokeHandler(req *http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + response, ok := smokeResponses[requestedIP] if !ok { return &http.Response{ @@ -135,6 +145,7 @@ func rateLimitedHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + return &http.Response{ StatusCode: http.StatusTooManyRequests, Body: nil, @@ -151,7 +162,9 @@ func searchHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + url, _ := url.Parse(req.URL.String()) + ipsParam := url.Query().Get("ips") if ipsParam == "" { return &http.Response{ @@ -163,6 +176,7 @@ func searchHandler(req *http.Request) *http.Response { totalIps := 0 notFound := 0 + ips := strings.Split(ipsParam, ",") for _, ip := range ips { _, ok := smokeResponses[ip] @@ -172,12 +186,15 @@ func searchHandler(req *http.Request) *http.Response { notFound++ } } + response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound) for _, ip := range ips { response += smokeResponses[ip] } + response += "]}" reader := io.NopCloser(strings.NewReader(response)) + return &http.Response{ StatusCode: http.StatusOK, Body: reader, @@ -190,7 +207,7 @@ func TestBadFireAuth(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) _, err := ctiClient.Fire(FireParams{}) - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestFireOk(t *testing.T) { @@ -198,19 +215,19 @@ func TestFireOk(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) data, err := cticlient.Fire(FireParams{}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 1 is the default data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 2 data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "4.2.3.4", data.Items[0].Ip) } func TestFirePaginator(t *testing.T) { @@ -219,17 +236,16 @@ func TestFirePaginator(t *testing.T) { })) paginator := NewFirePaginator(cticlient, FireParams{}) items, err := paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "1.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "4.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 0) - + require.NoError(t, err) + assert.Empty(t, items) } func TestBadSmokeAuth(t *testing.T) { @@ -237,13 +253,14 @@ func TestBadSmokeAuth(t *testing.T) { Transport: RoundTripFunc(smokeHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestSmokeInfoValidIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("1.1.1.1") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -257,6 +274,7 @@ func TestSmokeUnknownIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("42.42.42.42") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -270,20 +288,22 @@ func TestRateLimit(t *testing.T) { Transport: RoundTripFunc(rateLimitedHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrLimit.Error()) + require.EqualError(t, err, ErrLimit.Error()) } func TestSearchIPs(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(searchHandler), })) + resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"}) if err != nil { t.Fatalf("failed to search ips: %s", err) } + assert.Equal(t, 1, resp.Total) assert.Equal(t, 1, resp.NotFound) - assert.Equal(t, 1, len(resp.Items)) + assert.Len(t, resp.Items, 1) assert.Equal(t, "1.1.1.1", resp.Items[0].Ip) } diff --git a/pkg/cticlient/types_test.go b/pkg/cticlient/types_test.go index c20acc95a..a7308af35 100644 --- a/pkg/cticlient/types_test.go +++ b/pkg/cticlient/types_test.go @@ -88,27 +88,28 @@ func getSampleSmokeItem() SmokeItem { }, }, } + return emptyItem } func TestBasicSmokeItem(t *testing.T) { item := getSampleSmokeItem() - assert.Equal(t, item.GetAttackDetails(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetBehaviors(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.1)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(3)) - assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetAttackDetails()) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetBehaviors()) + assert.InDelta(t, 0.1, item.GetMaliciousnessScore(), 0.000001) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 3, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } func TestEmptySmokeItem(t *testing.T) { item := SmokeItem{} - assert.Equal(t, item.GetAttackDetails(), []string{}) - assert.Equal(t, item.GetBehaviors(), []string{}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.0)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(0)) - assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{}, item.GetAttackDetails()) + assert.Equal(t, []string{}, item.GetBehaviors()) + assert.InDelta(t, 0.0, item.GetMaliciousnessScore(), 0) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 0, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index 80ccadba4..fc3a236c5 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -78,6 +79,7 @@ func smokeHandler(req *http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + sample, ok := sampledata[requestedIP] if !ok { return &http.Response{ @@ -109,9 +111,11 @@ func smokeHandler(req *http.Request) *http.Response { func TestNillClient(t *testing.T) { defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { t.Fatalf("failed to init CTI : %s", err) } + item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, err, cticlient.ErrDisabled) assert.Equal(t, item, &cticlient.SmokeItem{}) @@ -119,6 +123,7 @@ func TestNillClient(t *testing.T) { func TestInvalidAuth(t *testing.T) { defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) } @@ -129,7 +134,7 @@ func TestInvalidAuth(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrUnauthorized) //CTI is now disabled, all requests should return empty @@ -139,14 +144,15 @@ func TestInvalidAuth(t *testing.T) { item, err = CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestNoKey(t *testing.T) { defer ShutdownCrowdsecCTI() + err := InitCrowdsecCTI(nil, nil, nil, nil) - assert.ErrorIs(t, err, cticlient.ErrDisabled) + require.ErrorIs(t, err, cticlient.ErrDisabled) //Replace the client created by InitCrowdsecCTI with one that uses a custom transport ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), @@ -154,12 +160,13 @@ func TestNoKey(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestCache(t *testing.T) { defer ShutdownCrowdsecCTI() + cacheDuration := 1 * time.Second if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) @@ -172,28 +179,27 @@ func TestCache(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") ctiResp := item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) time.Sleep(2 * time.Second) - assert.Equal(t, CTICache.Len(true), 0) + assert.Equal(t, 0, CTICache.Len(true)) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) - + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) } diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 4a6d5b74d..6b9cd15c7 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -28,17 +28,18 @@ var ( func getDBClient(t *testing.T) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), }) require.NoError(t, err) - return testDbClient + return testDBClient } func TestVisitor(t *testing.T) { @@ -109,17 +110,18 @@ func TestVisitor(t *testing.T) { if err != nil && test.err == nil { log.Fatalf("run : %s", err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.filter) } } - } } func TestMatch(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { glob string val string @@ -149,12 +151,15 @@ func TestMatch(t *testing.T) { "pattern": test.glob, "name": test.val, } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) if err != nil { t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err) } + ret, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + if isOk := assert.Equal(t, test.ret, ret); !isOk { t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret) } @@ -194,10 +199,10 @@ func TestDistanceHelper(t *testing.T) { } ret, err := expr.Run(vm, env) if test.valid { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.dist, ret) } else { - assert.NotNil(t, err) + require.Error(t, err) } }) } @@ -283,10 +288,12 @@ func TestRegexpInFile(t *testing.T) { if err != nil { log.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } @@ -335,28 +342,34 @@ func TestFileInit(t *testing.T) { if err != nil { log.Fatal(err) } - if test.types == "string" { + + switch test.types { + case "string": if _, ok := dataFile[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFile[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFile[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else if test.types == "regex" { + case "regex": if _, ok := dataFileRegex[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFileRegex[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFileRegex[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else { + default: if _, ok := dataFileRegex[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } + if _, ok := dataFile[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } } + log.Printf("test '%s' : OK", test.name) } } @@ -408,21 +421,23 @@ func TestFile(t *testing.T) { if err != nil { log.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - log.Printf("test '%s' : OK", test.name) + log.Printf("test '%s' : OK", test.name) } } func TestIpInRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -470,12 +485,11 @@ func TestIpInRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestIpToRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -543,13 +557,11 @@ func TestIpToRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestAtof(t *testing.T) { - err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -593,13 +605,14 @@ func TestUpper(t *testing.T) { } err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + v, ok := out.(string) if !ok { t.Fatalf("Upper() should return a string") @@ -612,6 +625,7 @@ func TestUpper(t *testing.T) { func TestTimeNow(t *testing.T) { now, _ := TimeNow() + ti, err := time.Parse(time.RFC3339, now.(string)) if err != nil { t.Fatalf("Error parsing the return value of TimeNow: %s", err) @@ -620,6 +634,7 @@ func TestTimeNow(t *testing.T) { if -1*time.Until(ti) > time.Second { t.Fatalf("TimeNow func should return time.Now().UTC()") } + log.Printf("test 'TimeNow()' : OK") } @@ -894,15 +909,14 @@ func TestLower(t *testing.T) { } func TestGetDecisionsCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", existingIP, err) } + // Add sample data to DB dbClient = getDBClient(t) @@ -921,11 +935,11 @@ func TestGetDecisionsCount(t *testing.T) { SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -982,12 +996,10 @@ func TestGetDecisionsCount(t *testing.T) { } } func TestGetDecisionsSinceCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", existingIP, err) } @@ -1008,8 +1020,9 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetOrigin("CAPI"). SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } + decision2 := dbClient.Ent.Decision.Create(). SetCreatedAt(time.Now().AddDate(0, 0, -1)). SetUntil(time.Now().AddDate(0, 0, -1)). @@ -1024,12 +1037,13 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetValue(existingIP). SetOrigin("CAPI"). SaveX(context.Background()) + if decision2 == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -1152,6 +1166,7 @@ func TestIsIp(t *testing.T) { if err := Init(nil); err != nil { log.Fatal(err) } + tests := []struct { name string expr string @@ -1235,17 +1250,18 @@ func TestIsIp(t *testing.T) { expectedBuildErr: true, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) assert.IsType(t, tc.expected, output) assert.Equal(t, tc.expected, output.(bool)) }) @@ -1255,6 +1271,7 @@ func TestIsIp(t *testing.T) { func TestToString(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { name string value interface{} @@ -1290,9 +1307,9 @@ func TestToString(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1338,16 +1355,16 @@ func TestB64Decode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) if tc.expectedRuntimeErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1412,9 +1429,9 @@ func TestParseKv(t *testing.T) { "out": outMap, } vm, err := expr.Compile(tc.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tc.expected, outMap["a"]) }) } diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 481c7d723..1bd45aa2d 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -7,6 +7,7 @@ import ( "github.com/antonmedv/expr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestJsonExtract(t *testing.T) { @@ -56,14 +57,14 @@ func TestJsonExtract(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } - } + func TestJsonExtractUnescape(t *testing.T) { if err := Init(nil); err != nil { log.Fatal(err) @@ -104,9 +105,9 @@ func TestJsonExtractUnescape(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -167,9 +168,9 @@ func TestJsonExtractSlice(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -223,9 +224,9 @@ func TestJsonExtractObject(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -233,7 +234,8 @@ func TestJsonExtractObject(t *testing.T) { func TestToJson(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string obj interface{} @@ -298,9 +300,9 @@ func TestToJson(t *testing.T) { "obj": test.obj, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -308,7 +310,8 @@ func TestToJson(t *testing.T) { func TestUnmarshalJSON(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string json string @@ -361,11 +364,10 @@ func TestUnmarshalJSON(t *testing.T) { "out": outMap, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, outMap["a"]) }) } - } diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index 98a22db0d..242ade049 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -353,7 +353,7 @@ func TestUnitFound(t *testing.T) { installed, err := env.UnitFound("crowdsec-setup-detect.service") require.NoError(err) - require.Equal(true, installed) + require.True(installed) } // TODO apply rules to filter a list of Service structs @@ -566,8 +566,8 @@ func TestDetectForcedUnit(t *testing.T) { func TestDetectForcedProcess(t *testing.T) { if runtime.GOOS == "windows" { - t.Skip("skipping on windows") // while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning. + t.Skip("skipping on windows") } require := require.New(t) diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go index c3261c647..14ca48cd2 100644 --- a/pkg/types/event_test.go +++ b/pkg/types/event_test.go @@ -73,7 +73,7 @@ func TestParseIPSources(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { ips := tt.evt.ParseIPSources() - assert.Equal(t, ips, tt.expected) + assert.Equal(t, tt.expected, ips) }) } }