CI: enable testifylint (#2696)

- reverse actual and expected values
 - use assert.False, assert.True
 - use assert.Len, assert.Emtpy
 - use require.Error, require.NoError
 - use assert.InDelta
This commit is contained in:
mmetc 2024-01-05 15:26:13 +01:00 committed by GitHub
parent da746f77d5
commit 5622ac8338
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 283 additions and 215 deletions

View file

@ -242,42 +242,6 @@ issues:
# Will fix, trivial - just beware of merge conflicts # Will fix, trivial - just beware of merge conflicts
- linters:
- testifylint
text: "expected-actual: need to reverse actual and expected values"
- linters:
- testifylint
text: "bool-compare: use assert.False"
- linters:
- testifylint
text: "len: use assert.Len"
- linters:
- testifylint
text: "bool-compare: use assert.True"
- linters:
- testifylint
text: "bool-compare: use require.True"
- linters:
- testifylint
text: "require-error: for error assertions use require"
- linters:
- testifylint
text: "error-nil: use assert.NoError"
- linters:
- testifylint
text: "error-nil: use assert.Error"
- linters:
- testifylint
text: "empty: use assert.Empty"
- linters: - linters:
- perfsprint - perfsprint
text: "fmt.Sprintf can be replaced .*" text: "fmt.Sprintf can be replaced .*"
@ -286,10 +250,6 @@ issues:
# Will fix, easy but some neurons required # Will fix, easy but some neurons required
# #
- linters:
- testifylint
text: "float-compare: use assert.InEpsilon .*or InDelta.*"
- linters: - linters:
- errorlint - errorlint
text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors" text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors"

View file

@ -40,15 +40,19 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error {
if err := f.UnmarshalConfig(cfg); err != nil { if err := f.UnmarshalConfig(cfg); err != nil {
return err return err
} }
if f.Mode == "" { if f.Mode == "" {
f.Mode = configuration.CAT_MODE f.Mode = configuration.CAT_MODE
} }
if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE { if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode) return fmt.Errorf("mode %s is not supported", f.Mode)
} }
if f.Toto == "" { if f.Toto == "" {
return fmt.Errorf("expect non-empty toto") return fmt.Errorf("expect non-empty toto")
} }
return nil return nil
} }
func (f *MockSource) GetMode() string { return f.Mode } func (f *MockSource) GetMode() string { return f.Mode }
@ -77,6 +81,7 @@ func appendMockSource() {
if GetDataSourceIface("mock") == nil { if GetDataSourceIface("mock") == nil {
AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } AcquisitionSources["mock"] = func() DataSource { return &MockSource{} }
} }
if GetDataSourceIface("mock_cant_run") == nil { if GetDataSourceIface("mock_cant_run") == nil {
AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} }
} }
@ -84,6 +89,7 @@ func appendMockSource() {
func TestDataSourceConfigure(t *testing.T) { func TestDataSourceConfigure(t *testing.T) {
appendMockSource() appendMockSource()
tests := []struct { tests := []struct {
TestName string TestName string
String string String string
@ -185,22 +191,22 @@ wowo: ajsajasjas
switch tc.TestName { switch tc.TestName {
case "basic_valid_config": case "basic_valid_config":
mock := (*ds).Dump().(*MockSource) mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1") assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, mock.Mode, "cat") assert.Equal(t, "cat", mock.Mode)
assert.Equal(t, mock.logger.Logger.Level, log.InfoLevel) assert.Equal(t, log.InfoLevel, mock.logger.Logger.Level)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
case "basic_debug_config": case "basic_debug_config":
mock := (*ds).Dump().(*MockSource) mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1") assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, mock.Mode, "cat") assert.Equal(t, "cat", mock.Mode)
assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
case "basic_tailmode_config": case "basic_tailmode_config":
mock := (*ds).Dump().(*MockSource) mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1") assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, mock.Mode, "tail") assert.Equal(t, "tail", mock.Mode)
assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
} }
}) })
} }
@ -208,6 +214,7 @@ wowo: ajsajasjas
func TestLoadAcquisitionFromFile(t *testing.T) { func TestLoadAcquisitionFromFile(t *testing.T) {
appendMockSource() appendMockSource()
tests := []struct { tests := []struct {
TestName string TestName string
Config csconfig.CrowdsecServiceCfg Config csconfig.CrowdsecServiceCfg
@ -284,7 +291,6 @@ func TestLoadAcquisitionFromFile(t *testing.T) {
assert.Len(t, dss, tc.ExpectedLen) assert.Len(t, dss, tc.ExpectedLen)
}) })
} }
} }
@ -304,9 +310,11 @@ func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error {
if f.Mode == "" { if f.Mode == "" {
f.Mode = configuration.CAT_MODE f.Mode = configuration.CAT_MODE
} }
if f.Mode != configuration.CAT_MODE { if f.Mode != configuration.CAT_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode) return fmt.Errorf("mode %s is not supported", f.Mode)
} }
return nil return nil
} }
@ -319,6 +327,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro
evt.Line.Src = "test" evt.Line.Src = "test"
out <- evt out <- evt
} }
return nil return nil
} }
func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error {
@ -345,9 +354,11 @@ func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error {
if f.Mode == "" { if f.Mode == "" {
f.Mode = configuration.TAIL_MODE f.Mode = configuration.TAIL_MODE
} }
if f.Mode != configuration.TAIL_MODE { if f.Mode != configuration.TAIL_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode) return fmt.Errorf("mode %s is not supported", f.Mode)
} }
return nil return nil
} }
@ -364,6 +375,7 @@ func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro
out <- evt out <- evt
} }
<-t.Dying() <-t.Dying()
return nil return nil
} }
func (f *MockTail) CanRun() error { return nil } func (f *MockTail) CanRun() error { return nil }
@ -446,6 +458,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb)
out <- evt out <- evt
} }
t.Kill(fmt.Errorf("got error (tomb)")) t.Kill(fmt.Errorf("got error (tomb)"))
return fmt.Errorf("got error") return fmt.Errorf("got error")
} }
@ -499,6 +512,7 @@ func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, l
if dsn != "test_expect" { if dsn != "test_expect" {
return fmt.Errorf("unexpected value") return fmt.Errorf("unexpected value")
} }
return nil return nil
} }
func (f *MockSourceByDSN) GetUuid() string { return "" } func (f *MockSourceByDSN) GetUuid() string { return "" }

View file

@ -57,6 +57,7 @@ container_name:
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "docker", "type": "docker",
}) })
for _, test := range tests { for _, test := range tests {
f := DockerSource{} f := DockerSource{}
err := f.Configure([]byte(test.config), subLogger) err := f.Configure([]byte(test.config), subLogger)
@ -66,12 +67,15 @@ container_name:
func TestConfigureDSN(t *testing.T) { func TestConfigureDSN(t *testing.T) {
log.Infof("Test 'TestConfigureDSN'") log.Infof("Test 'TestConfigureDSN'")
var dockerHost string var dockerHost string
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
dockerHost = "npipe:////./pipe/docker_engine" dockerHost = "npipe:////./pipe/docker_engine"
} else { } else {
dockerHost = "unix:///var/run/podman/podman.sock" dockerHost = "unix:///var/run/podman/podman.sock"
} }
tests := []struct { tests := []struct {
name string name string
dsn string dsn string
@ -106,6 +110,7 @@ func TestConfigureDSN(t *testing.T) {
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "docker", "type": "docker",
}) })
for _, test := range tests { for _, test := range tests {
f := DockerSource{} f := DockerSource{}
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "")
@ -156,8 +161,11 @@ container_name_regexp:
} }
for _, ts := range tests { for _, ts := range tests {
var logger *log.Logger var (
var subLogger *log.Entry logger *log.Logger
subLogger *log.Entry
)
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
logger.SetLevel(ts.logLevel) logger.SetLevel(ts.logLevel)
subLogger = logger.WithFields(log.Fields{ subLogger = logger.WithFields(log.Fields{
@ -173,10 +181,12 @@ container_name_regexp:
dockerTomb := tomb.Tomb{} dockerTomb := tomb.Tomb{}
out := make(chan types.Event) out := make(chan types.Event)
dockerSource := DockerSource{} dockerSource := DockerSource{}
err := dockerSource.Configure([]byte(ts.config), subLogger) err := dockerSource.Configure([]byte(ts.config), subLogger)
if err != nil { if err != nil {
t.Fatalf("Unexpected error : %s", err) t.Fatalf("Unexpected error : %s", err)
} }
dockerSource.Client = new(mockDockerCli) dockerSource.Client = new(mockDockerCli)
actualLines := 0 actualLines := 0
readerTomb := &tomb.Tomb{} readerTomb := &tomb.Tomb{}
@ -204,21 +214,23 @@ container_name_regexp:
if err := readerTomb.Wait(); err != nil { if err := readerTomb.Wait(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if ts.expectedLines != 0 { if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, actualLines) assert.Equal(t, ts.expectedLines, actualLines)
} }
err = streamTomb.Wait() err = streamTomb.Wait()
if err != nil { if err != nil {
t.Fatalf("docker acquisition error: %s", err) t.Fatalf("docker acquisition error: %s", err)
} }
} }
} }
func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) { func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) {
if readLogs == true { if readLogs == true {
return []dockerTypes.Container{}, nil return []dockerTypes.Container{}, nil
} }
containers := make([]dockerTypes.Container, 0) containers := make([]dockerTypes.Container, 0)
container := &dockerTypes.Container{ container := &dockerTypes.Container{
ID: "12456", ID: "12456",
@ -233,16 +245,20 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o
if readLogs == true { if readLogs == true {
return io.NopCloser(strings.NewReader("")), nil return io.NopCloser(strings.NewReader("")), nil
} }
readLogs = true readLogs = true
data := []string{"docker\n", "test\n", "1234\n"} data := []string{"docker\n", "test\n", "1234\n"}
ret := "" ret := ""
for _, line := range data { for _, line := range data {
startLineByte := make([]byte, 8) startLineByte := make([]byte, 8)
binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream
binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line)))
ret += fmt.Sprintf("%s%s", startLineByte, line) ret += fmt.Sprintf("%s%s", startLineByte, line)
} }
r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser
return r, nil return r, nil
} }
@ -252,6 +268,7 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke
Tty: false, Tty: false,
}, },
} }
return r, nil return r, nil
} }
@ -285,8 +302,11 @@ func TestOneShot(t *testing.T) {
} }
for _, ts := range tests { for _, ts := range tests {
var subLogger *log.Entry var (
var logger *log.Logger subLogger *log.Entry
logger *log.Logger
)
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
logger.SetLevel(ts.logLevel) logger.SetLevel(ts.logLevel)
subLogger = logger.WithFields(log.Fields{ subLogger = logger.WithFields(log.Fields{
@ -307,6 +327,7 @@ func TestOneShot(t *testing.T) {
if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil { if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil {
t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err) t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err)
} }
dockerClient.Client = new(mockDockerCli) dockerClient.Client = new(mockDockerCli)
out := make(chan types.Event, 100) out := make(chan types.Event, 100)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
@ -315,8 +336,7 @@ func TestOneShot(t *testing.T) {
// else we do the check before actualLines is incremented ... // else we do the check before actualLines is incremented ...
if ts.expectedLines != 0 { if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, len(out)) assert.Len(t, out, ts.expectedLines)
} }
} }
} }

View file

@ -21,6 +21,7 @@ func TestBadConfiguration(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows") t.Skip("Skipping test on windows")
} }
tests := []struct { tests := []struct {
config string config string
expectedErr string expectedErr string
@ -48,6 +49,7 @@ journalctl_filter:
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "journalctl", "type": "journalctl",
}) })
for _, test := range tests { for _, test := range tests {
f := JournalCtlSource{} f := JournalCtlSource{}
err := f.Configure([]byte(test.config), subLogger) err := f.Configure([]byte(test.config), subLogger)
@ -59,6 +61,7 @@ func TestConfigureDSN(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows") t.Skip("Skipping test on windows")
} }
tests := []struct { tests := []struct {
dsn string dsn string
expectedErr string expectedErr string
@ -92,9 +95,11 @@ func TestConfigureDSN(t *testing.T) {
expectedErr: "", expectedErr: "",
}, },
} }
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "journalctl", "type": "journalctl",
}) })
for _, test := range tests { for _, test := range tests {
f := JournalCtlSource{} f := JournalCtlSource{}
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "")
@ -106,6 +111,7 @@ func TestOneShot(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows") t.Skip("Skipping test on windows")
} }
tests := []struct { tests := []struct {
config string config string
expectedErr string expectedErr string
@ -137,9 +143,12 @@ journalctl_filter:
}, },
} }
for _, ts := range tests { for _, ts := range tests {
var logger *log.Logger var (
var subLogger *log.Entry logger *log.Logger
var hook *test.Hook subLogger *log.Entry
hook *test.Hook
)
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
logger, hook = test.NewNullLogger() logger, hook = test.NewNullLogger()
logger.SetLevel(ts.logLevel) logger.SetLevel(ts.logLevel)
@ -151,27 +160,32 @@ journalctl_filter:
"type": "journalctl", "type": "journalctl",
}) })
} }
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
out := make(chan types.Event, 100) out := make(chan types.Event, 100)
j := JournalCtlSource{} j := JournalCtlSource{}
err := j.Configure([]byte(ts.config), subLogger) err := j.Configure([]byte(ts.config), subLogger)
if err != nil { if err != nil {
t.Fatalf("Unexpected error : %s", err) t.Fatalf("Unexpected error : %s", err)
} }
err = j.OneShotAcquisition(out, &tomb) err = j.OneShotAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr) cstest.AssertErrorContains(t, err, ts.expectedErr)
if err != nil { if err != nil {
continue continue
} }
if ts.expectedLines != 0 { if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, len(out)) assert.Len(t, out, ts.expectedLines)
} }
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
if hook.LastEntry() == nil { if hook.LastEntry() == nil {
t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput)
} }
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset() hook.Reset()
} }
@ -182,6 +196,7 @@ func TestStreaming(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows") t.Skip("Skipping test on windows")
} }
tests := []struct { tests := []struct {
config string config string
expectedErr string expectedErr string
@ -202,9 +217,12 @@ journalctl_filter:
}, },
} }
for _, ts := range tests { for _, ts := range tests {
var logger *log.Logger var (
var subLogger *log.Entry logger *log.Logger
var hook *test.Hook subLogger *log.Entry
hook *test.Hook
)
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
logger, hook = test.NewNullLogger() logger, hook = test.NewNullLogger()
logger.SetLevel(ts.logLevel) logger.SetLevel(ts.logLevel)
@ -216,14 +234,18 @@ journalctl_filter:
"type": "journalctl", "type": "journalctl",
}) })
} }
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
out := make(chan types.Event) out := make(chan types.Event)
j := JournalCtlSource{} j := JournalCtlSource{}
err := j.Configure([]byte(ts.config), subLogger) err := j.Configure([]byte(ts.config), subLogger)
if err != nil { if err != nil {
t.Fatalf("Unexpected error : %s", err) t.Fatalf("Unexpected error : %s", err)
} }
actualLines := 0 actualLines := 0
if ts.expectedLines != 0 { if ts.expectedLines != 0 {
go func() { go func() {
READLOOP: READLOOP:
@ -240,6 +262,7 @@ journalctl_filter:
err = j.StreamingAcquisition(out, &tomb) err = j.StreamingAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr) cstest.AssertErrorContains(t, err, ts.expectedErr)
if err != nil { if err != nil {
continue continue
} }
@ -248,16 +271,20 @@ journalctl_filter:
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
assert.Equal(t, ts.expectedLines, actualLines) assert.Equal(t, ts.expectedLines, actualLines)
} }
tomb.Kill(nil) tomb.Kill(nil)
tomb.Wait() tomb.Wait()
output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput() output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput()
if string(output) != "" { if string(output) != "" {
t.Fatalf("Found a journalctl process after killing the tomb !") t.Fatalf("Found a journalctl process after killing the tomb !")
} }
if ts.expectedOutput != "" { if ts.expectedOutput != "" {
if hook.LastEntry() == nil { if hook.LastEntry() == nil {
t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput)
} }
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset() hook.Reset()
} }
@ -270,5 +297,6 @@ func TestMain(m *testing.M) {
fullPath := filepath.Join(currentDir, "test_files") fullPath := filepath.Join(currentDir, "test_files")
os.Setenv("PATH", fullPath+":"+os.Getenv("PATH")) os.Setenv("PATH", fullPath+":"+os.Getenv("PATH"))
} }
os.Exit(m.Run()) os.Exit(m.Run())
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2" "gopkg.in/tomb.v2"
) )
@ -78,24 +79,23 @@ webhook_path: /k8s-audit`,
err := f.UnmarshalConfig([]byte(test.config)) err := f.UnmarshalConfig([]byte(test.config))
assert.NoError(t, err) require.NoError(t, err)
err = f.Configure([]byte(test.config), subLogger) err = f.Configure([]byte(test.config), subLogger)
assert.NoError(t, err) require.NoError(t, err)
f.StreamingAcquisition(out, tb) f.StreamingAcquisition(out, tb)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
tb.Kill(nil) tb.Kill(nil)
err = tb.Wait() err = tb.Wait()
if test.expectedErr != "" { if test.expectedErr != "" {
assert.ErrorContains(t, err, test.expectedErr) require.ErrorContains(t, err, test.expectedErr)
return return
} }
assert.NoError(t, err) require.NoError(t, err)
}) })
} }
} }
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
@ -252,10 +252,10 @@ webhook_path: /k8s-audit`,
f := KubernetesAuditSource{} f := KubernetesAuditSource{}
err := f.UnmarshalConfig([]byte(test.config)) err := f.UnmarshalConfig([]byte(test.config))
assert.NoError(t, err) require.NoError(t, err)
err = f.Configure([]byte(test.config), subLogger) err = f.Configure([]byte(test.config), subLogger)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -268,11 +268,11 @@ webhook_path: /k8s-audit`,
assert.Equal(t, test.expectedStatusCode, res.StatusCode) assert.Equal(t, test.expectedStatusCode, res.StatusCode)
//time.Sleep(1 * time.Second) //time.Sleep(1 * time.Second)
assert.NoError(t, err) require.NoError(t, err)
tb.Kill(nil) tb.Kill(nil)
err = tb.Wait() err = tb.Wait()
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.eventCount, eventCount) assert.Equal(t, test.eventCount, eventCount)
}) })

View file

@ -11,6 +11,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/svc/eventlog" "golang.org/x/sys/windows/svc/eventlog"
"gopkg.in/tomb.v2" "gopkg.in/tomb.v2"
) )
@ -124,7 +125,7 @@ event_level: bla`,
} }
assert.Contains(t, err.Error(), test.expectedErr) assert.Contains(t, err.Error(), test.expectedErr)
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectedQuery, q) assert.Equal(t, test.expectedQuery, q)
} }
} }
@ -221,9 +222,8 @@ event_ids:
} }
} }
if test.expectedLines == nil { if test.expectedLines == nil {
assert.Equal(t, 0, len(linesRead)) assert.Empty(t, linesRead)
} else { } else {
assert.Equal(t, len(test.expectedLines), len(linesRead))
assert.Equal(t, test.expectedLines, linesRead) assert.Equal(t, test.expectedLines, linesRead)
} }
to.Kill(nil) to.Kill(nil)

View file

@ -7,6 +7,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewAlertContext(t *testing.T) { func TestNewAlertContext(t *testing.T) {
@ -29,8 +30,7 @@ func TestNewAlertContext(t *testing.T) {
for _, test := range tests { for _, test := range tests {
fmt.Printf("Running test '%s'\n", test.name) fmt.Printf("Running test '%s'\n", test.name)
err := NewAlertContext(test.contextToSend, test.valueLength) err := NewAlertContext(test.contextToSend, test.valueLength)
assert.ErrorIs(t, err, test.expectedErr) require.ErrorIs(t, err, test.expectedErr)
} }
} }
@ -193,7 +193,7 @@ func TestEventToContext(t *testing.T) {
for _, test := range tests { for _, test := range tests {
fmt.Printf("Running test '%s'\n", test.name) fmt.Printf("Running test '%s'\n", test.name)
err := NewAlertContext(test.contextToSend, test.valueLength) err := NewAlertContext(test.contextToSend, test.valueLength)
assert.ErrorIs(t, err, nil) require.NoError(t, err)
metas, _ := EventToContext(test.events) metas, _ := EventToContext(test.events)
assert.ElementsMatch(t, test.expectedResult, metas) assert.ElementsMatch(t, test.expectedResult, metas)

View file

@ -149,7 +149,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
t := s.T() t := s.T()
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -182,7 +182,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
err = json.Unmarshal(content, &alerts) err = json.Unmarshal(content, &alerts)
log.Printf("content-> %s", content) log.Printf("content-> %s", content)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
} }
@ -199,7 +199,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
s.writeconfig(cfg) s.writeconfig(cfg)
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -215,11 +215,11 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// after 1 seconds, we should have data // after 1 seconds, we should have data
content, err := os.ReadFile("./out") content, err := os.ReadFile("./out")
assert.NoError(t, err) require.NoError(t, err)
var alerts []models.Alert var alerts []models.Alert
err = json.Unmarshal(content, &alerts) err = json.Unmarshal(content, &alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 3) assert.Len(t, alerts, 3)
} }
@ -235,7 +235,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
s.writeconfig(cfg) s.writeconfig(cfg)
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -259,7 +259,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
var alerts []models.Alert var alerts []models.Alert
err = json.Unmarshal(content, &alerts) err = json.Unmarshal(content, &alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 4) assert.Len(t, alerts, 4)
} }
@ -275,7 +275,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
s.writeconfig(cfg) s.writeconfig(cfg)
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -306,11 +306,11 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
// two notifications, one with 4 alerts, one with 2 alerts // two notifications, one with 4 alerts, one with 2 alerts
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 4) assert.Len(t, alerts, 4)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 2) assert.Len(t, alerts, 2)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
@ -328,7 +328,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
s.writeconfig(cfg) s.writeconfig(cfg)
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -348,7 +348,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
var alerts []models.Alert var alerts []models.Alert
err = json.Unmarshal(content, &alerts) err = json.Unmarshal(content, &alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
} }
@ -358,7 +358,7 @@ func (s *PluginSuite) TestBrokerRunSimple() {
t := s.T() t := s.T()
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -382,11 +382,11 @@ func (s *PluginSuite) TestBrokerRunSimple() {
// two notifications, one alert each // two notifications, one alert each
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)

View file

@ -70,7 +70,7 @@ func (s *PluginSuite) TestBrokerRun() {
t := s.T() t := s.T()
pb, err := s.InitBroker(nil) pb, err := s.InitBroker(nil)
assert.NoError(t, err) require.NoError(t, err)
tomb := tomb.Tomb{} tomb := tomb.Tomb{}
go pb.Run(&tomb) go pb.Run(&tomb)
@ -94,11 +94,11 @@ func (s *PluginSuite) TestBrokerRun() {
// two notifications, one alert each // two notifications, one alert each
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, alerts, 1) assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts) err = decoder.Decode(&alerts)

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/ptr"
) )
@ -36,25 +37,30 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
// wip // wip
func fireHandler(req *http.Request) *http.Response { func fireHandler(req *http.Request) *http.Response {
var err error var err error
apiKey := req.Header.Get("x-api-key") apiKey := req.Header.Get("x-api-key")
if apiKey != validApiKey { if apiKey != validApiKey {
log.Warningf("invalid api key: %s", apiKey) log.Warningf("invalid api key: %s", apiKey)
return &http.Response{ return &http.Response{
StatusCode: http.StatusForbidden, StatusCode: http.StatusForbidden,
Body: nil, Body: nil,
Header: make(http.Header), Header: make(http.Header),
} }
} }
//unmarshal data //unmarshal data
if fireResponses == nil { if fireResponses == nil {
page1, err := os.ReadFile("tests/fire-page1.json") page1, err := os.ReadFile("tests/fire-page1.json")
if err != nil { if err != nil {
panic("can't read file") panic("can't read file")
} }
page2, err := os.ReadFile("tests/fire-page2.json") page2, err := os.ReadFile("tests/fire-page2.json")
if err != nil { if err != nil {
panic("can't read file") panic("can't read file")
} }
fireResponses = []string{string(page1), string(page2)} fireResponses = []string{string(page1), string(page2)}
} }
//let's assume we have two valid pages. //let's assume we have two valid pages.
@ -70,6 +76,7 @@ func fireHandler(req *http.Request) *http.Response {
//how to react if you give a page number that is too big ? //how to react if you give a page number that is too big ?
if page > len(fireResponses) { if page > len(fireResponses) {
log.Warningf(" page too big %d vs %d", page, len(fireResponses)) log.Warningf(" page too big %d vs %d", page, len(fireResponses))
emptyResponse := `{ emptyResponse := `{
"_links": { "_links": {
"first": { "first": {
@ -82,8 +89,10 @@ func fireHandler(req *http.Request) *http.Response {
"items": [] "items": []
} }
` `
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))} return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))}
} }
reader := io.NopCloser(strings.NewReader(fireResponses[page-1])) reader := io.NopCloser(strings.NewReader(fireResponses[page-1]))
//we should care about limit too //we should care about limit too
return &http.Response{ return &http.Response{
@ -106,6 +115,7 @@ func smokeHandler(req *http.Request) *http.Response {
} }
requestedIP := strings.Split(req.URL.Path, "/")[3] requestedIP := strings.Split(req.URL.Path, "/")[3]
response, ok := smokeResponses[requestedIP] response, ok := smokeResponses[requestedIP]
if !ok { if !ok {
return &http.Response{ return &http.Response{
@ -135,6 +145,7 @@ func rateLimitedHandler(req *http.Request) *http.Response {
Header: make(http.Header), Header: make(http.Header),
} }
} }
return &http.Response{ return &http.Response{
StatusCode: http.StatusTooManyRequests, StatusCode: http.StatusTooManyRequests,
Body: nil, Body: nil,
@ -151,7 +162,9 @@ func searchHandler(req *http.Request) *http.Response {
Header: make(http.Header), Header: make(http.Header),
} }
} }
url, _ := url.Parse(req.URL.String()) url, _ := url.Parse(req.URL.String())
ipsParam := url.Query().Get("ips") ipsParam := url.Query().Get("ips")
if ipsParam == "" { if ipsParam == "" {
return &http.Response{ return &http.Response{
@ -163,6 +176,7 @@ func searchHandler(req *http.Request) *http.Response {
totalIps := 0 totalIps := 0
notFound := 0 notFound := 0
ips := strings.Split(ipsParam, ",") ips := strings.Split(ipsParam, ",")
for _, ip := range ips { for _, ip := range ips {
_, ok := smokeResponses[ip] _, ok := smokeResponses[ip]
@ -172,12 +186,15 @@ func searchHandler(req *http.Request) *http.Response {
notFound++ notFound++
} }
} }
response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound) response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound)
for _, ip := range ips { for _, ip := range ips {
response += smokeResponses[ip] response += smokeResponses[ip]
} }
response += "]}" response += "]}"
reader := io.NopCloser(strings.NewReader(response)) reader := io.NopCloser(strings.NewReader(response))
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: reader, Body: reader,
@ -190,7 +207,7 @@ func TestBadFireAuth(t *testing.T) {
Transport: RoundTripFunc(fireHandler), Transport: RoundTripFunc(fireHandler),
})) }))
_, err := ctiClient.Fire(FireParams{}) _, err := ctiClient.Fire(FireParams{})
assert.EqualError(t, err, ErrUnauthorized.Error()) require.EqualError(t, err, ErrUnauthorized.Error())
} }
func TestFireOk(t *testing.T) { func TestFireOk(t *testing.T) {
@ -198,19 +215,19 @@ func TestFireOk(t *testing.T) {
Transport: RoundTripFunc(fireHandler), Transport: RoundTripFunc(fireHandler),
})) }))
data, err := cticlient.Fire(FireParams{}) data, err := cticlient.Fire(FireParams{})
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(data.Items), 3) assert.Len(t, data.Items, 3)
assert.Equal(t, data.Items[0].Ip, "1.2.3.4") assert.Equal(t, "1.2.3.4", data.Items[0].Ip)
//page 1 is the default //page 1 is the default
data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)}) data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)})
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(data.Items), 3) assert.Len(t, data.Items, 3)
assert.Equal(t, data.Items[0].Ip, "1.2.3.4") assert.Equal(t, "1.2.3.4", data.Items[0].Ip)
//page 2 //page 2
data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)}) data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)})
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(data.Items), 3) assert.Len(t, data.Items, 3)
assert.Equal(t, data.Items[0].Ip, "4.2.3.4") assert.Equal(t, "4.2.3.4", data.Items[0].Ip)
} }
func TestFirePaginator(t *testing.T) { func TestFirePaginator(t *testing.T) {
@ -219,17 +236,16 @@ func TestFirePaginator(t *testing.T) {
})) }))
paginator := NewFirePaginator(cticlient, FireParams{}) paginator := NewFirePaginator(cticlient, FireParams{})
items, err := paginator.Next() items, err := paginator.Next()
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(items), 3) assert.Len(t, items, 3)
assert.Equal(t, items[0].Ip, "1.2.3.4") assert.Equal(t, "1.2.3.4", items[0].Ip)
items, err = paginator.Next() items, err = paginator.Next()
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(items), 3) assert.Len(t, items, 3)
assert.Equal(t, items[0].Ip, "4.2.3.4") assert.Equal(t, "4.2.3.4", items[0].Ip)
items, err = paginator.Next() items, err = paginator.Next()
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, len(items), 0) assert.Empty(t, items)
} }
func TestBadSmokeAuth(t *testing.T) { func TestBadSmokeAuth(t *testing.T) {
@ -237,13 +253,14 @@ func TestBadSmokeAuth(t *testing.T) {
Transport: RoundTripFunc(smokeHandler), Transport: RoundTripFunc(smokeHandler),
})) }))
_, err := ctiClient.GetIPInfo("1.1.1.1") _, err := ctiClient.GetIPInfo("1.1.1.1")
assert.EqualError(t, err, ErrUnauthorized.Error()) require.EqualError(t, err, ErrUnauthorized.Error())
} }
func TestSmokeInfoValidIP(t *testing.T) { func TestSmokeInfoValidIP(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler), Transport: RoundTripFunc(smokeHandler),
})) }))
resp, err := ctiClient.GetIPInfo("1.1.1.1") resp, err := ctiClient.GetIPInfo("1.1.1.1")
if err != nil { if err != nil {
t.Fatalf("failed to get ip info: %s", err) t.Fatalf("failed to get ip info: %s", err)
@ -257,6 +274,7 @@ func TestSmokeUnknownIP(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler), Transport: RoundTripFunc(smokeHandler),
})) }))
resp, err := ctiClient.GetIPInfo("42.42.42.42") resp, err := ctiClient.GetIPInfo("42.42.42.42")
if err != nil { if err != nil {
t.Fatalf("failed to get ip info: %s", err) t.Fatalf("failed to get ip info: %s", err)
@ -270,20 +288,22 @@ func TestRateLimit(t *testing.T) {
Transport: RoundTripFunc(rateLimitedHandler), Transport: RoundTripFunc(rateLimitedHandler),
})) }))
_, err := ctiClient.GetIPInfo("1.1.1.1") _, err := ctiClient.GetIPInfo("1.1.1.1")
assert.EqualError(t, err, ErrLimit.Error()) require.EqualError(t, err, ErrLimit.Error())
} }
func TestSearchIPs(t *testing.T) { func TestSearchIPs(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(searchHandler), Transport: RoundTripFunc(searchHandler),
})) }))
resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"}) resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"})
if err != nil { if err != nil {
t.Fatalf("failed to search ips: %s", err) t.Fatalf("failed to search ips: %s", err)
} }
assert.Equal(t, 1, resp.Total) assert.Equal(t, 1, resp.Total)
assert.Equal(t, 1, resp.NotFound) assert.Equal(t, 1, resp.NotFound)
assert.Equal(t, 1, len(resp.Items)) assert.Len(t, resp.Items, 1)
assert.Equal(t, "1.1.1.1", resp.Items[0].Ip) assert.Equal(t, "1.1.1.1", resp.Items[0].Ip)
} }

View file

@ -88,27 +88,28 @@ func getSampleSmokeItem() SmokeItem {
}, },
}, },
} }
return emptyItem return emptyItem
} }
func TestBasicSmokeItem(t *testing.T) { func TestBasicSmokeItem(t *testing.T) {
item := getSampleSmokeItem() item := getSampleSmokeItem()
assert.Equal(t, item.GetAttackDetails(), []string{"ssh:bruteforce"}) assert.Equal(t, []string{"ssh:bruteforce"}, item.GetAttackDetails())
assert.Equal(t, item.GetBehaviors(), []string{"ssh:bruteforce"}) assert.Equal(t, []string{"ssh:bruteforce"}, item.GetBehaviors())
assert.Equal(t, item.GetMaliciousnessScore(), float32(0.1)) assert.InDelta(t, 0.1, item.GetMaliciousnessScore(), 0.000001)
assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) assert.False(t, item.IsPartOfCommunityBlocklist())
assert.Equal(t, item.GetBackgroundNoiseScore(), int(3)) assert.Equal(t, 3, item.GetBackgroundNoiseScore())
assert.Equal(t, item.GetFalsePositives(), []string{}) assert.Equal(t, []string{}, item.GetFalsePositives())
assert.Equal(t, item.IsFalsePositive(), false) assert.False(t, item.IsFalsePositive())
} }
func TestEmptySmokeItem(t *testing.T) { func TestEmptySmokeItem(t *testing.T) {
item := SmokeItem{} item := SmokeItem{}
assert.Equal(t, item.GetAttackDetails(), []string{}) assert.Equal(t, []string{}, item.GetAttackDetails())
assert.Equal(t, item.GetBehaviors(), []string{}) assert.Equal(t, []string{}, item.GetBehaviors())
assert.Equal(t, item.GetMaliciousnessScore(), float32(0.0)) assert.InDelta(t, 0.0, item.GetMaliciousnessScore(), 0)
assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) assert.False(t, item.IsPartOfCommunityBlocklist())
assert.Equal(t, item.GetBackgroundNoiseScore(), int(0)) assert.Equal(t, 0, item.GetBackgroundNoiseScore())
assert.Equal(t, item.GetFalsePositives(), []string{}) assert.Equal(t, []string{}, item.GetFalsePositives())
assert.Equal(t, item.IsFalsePositive(), false) assert.False(t, item.IsFalsePositive())
} }

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/ptr"
@ -78,6 +79,7 @@ func smokeHandler(req *http.Request) *http.Response {
} }
requestedIP := strings.Split(req.URL.Path, "/")[3] requestedIP := strings.Split(req.URL.Path, "/")[3]
sample, ok := sampledata[requestedIP] sample, ok := sampledata[requestedIP]
if !ok { if !ok {
return &http.Response{ return &http.Response{
@ -109,9 +111,11 @@ func smokeHandler(req *http.Request) *http.Response {
func TestNillClient(t *testing.T) { func TestNillClient(t *testing.T) {
defer ShutdownCrowdsecCTI() defer ShutdownCrowdsecCTI()
if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) {
t.Fatalf("failed to init CTI : %s", err) t.Fatalf("failed to init CTI : %s", err)
} }
item, err := CrowdsecCTI("1.2.3.4") item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, err, cticlient.ErrDisabled) assert.Equal(t, err, cticlient.ErrDisabled)
assert.Equal(t, item, &cticlient.SmokeItem{}) assert.Equal(t, item, &cticlient.SmokeItem{})
@ -119,6 +123,7 @@ func TestNillClient(t *testing.T) {
func TestInvalidAuth(t *testing.T) { func TestInvalidAuth(t *testing.T) {
defer ShutdownCrowdsecCTI() defer ShutdownCrowdsecCTI()
if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil { if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil {
t.Fatalf("failed to init CTI : %s", err) t.Fatalf("failed to init CTI : %s", err)
} }
@ -129,7 +134,7 @@ func TestInvalidAuth(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4") item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{}) assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false) assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrUnauthorized) assert.Equal(t, err, cticlient.ErrUnauthorized)
//CTI is now disabled, all requests should return empty //CTI is now disabled, all requests should return empty
@ -139,14 +144,15 @@ func TestInvalidAuth(t *testing.T) {
item, err = CrowdsecCTI("1.2.3.4") item, err = CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{}) assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false) assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrDisabled) assert.Equal(t, err, cticlient.ErrDisabled)
} }
func TestNoKey(t *testing.T) { func TestNoKey(t *testing.T) {
defer ShutdownCrowdsecCTI() defer ShutdownCrowdsecCTI()
err := InitCrowdsecCTI(nil, nil, nil, nil) err := InitCrowdsecCTI(nil, nil, nil, nil)
assert.ErrorIs(t, err, cticlient.ErrDisabled) require.ErrorIs(t, err, cticlient.ErrDisabled)
//Replace the client created by InitCrowdsecCTI with one that uses a custom transport //Replace the client created by InitCrowdsecCTI with one that uses a custom transport
ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{ ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler), Transport: RoundTripFunc(smokeHandler),
@ -154,12 +160,13 @@ func TestNoKey(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4") item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{}) assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false) assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrDisabled) assert.Equal(t, err, cticlient.ErrDisabled)
} }
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
defer ShutdownCrowdsecCTI() defer ShutdownCrowdsecCTI()
cacheDuration := 1 * time.Second cacheDuration := 1 * time.Second
if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil { if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil {
t.Fatalf("failed to init CTI : %s", err) t.Fatalf("failed to init CTI : %s", err)
@ -172,28 +179,27 @@ func TestCache(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4") item, err := CrowdsecCTI("1.2.3.4")
ctiResp := item.(*cticlient.SmokeItem) ctiResp := item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true) assert.True(t, CTIApiEnabled)
assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, 1, CTICache.Len(true))
assert.Equal(t, err, nil) require.NoError(t, err)
item, err = CrowdsecCTI("1.2.3.4") item, err = CrowdsecCTI("1.2.3.4")
ctiResp = item.(*cticlient.SmokeItem) ctiResp = item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true) assert.True(t, CTIApiEnabled)
assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, 1, CTICache.Len(true))
assert.Equal(t, err, nil) require.NoError(t, err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
assert.Equal(t, CTICache.Len(true), 0) assert.Equal(t, 0, CTICache.Len(true))
item, err = CrowdsecCTI("1.2.3.4") item, err = CrowdsecCTI("1.2.3.4")
ctiResp = item.(*cticlient.SmokeItem) ctiResp = item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true) assert.True(t, CTIApiEnabled)
assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, 1, CTICache.Len(true))
assert.Equal(t, err, nil) require.NoError(t, err)
} }

View file

@ -28,17 +28,18 @@ var (
func getDBClient(t *testing.T) *database.Client { func getDBClient(t *testing.T) *database.Client {
t.Helper() t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite") dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err) require.NoError(t, err)
testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{ testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{
Type: "sqlite", Type: "sqlite",
DbName: "crowdsec", DbName: "crowdsec",
DbPath: dbPath.Name(), DbPath: dbPath.Name(),
}) })
require.NoError(t, err) require.NoError(t, err)
return testDbClient return testDBClient
} }
func TestVisitor(t *testing.T) { func TestVisitor(t *testing.T) {
@ -109,17 +110,18 @@ func TestVisitor(t *testing.T) {
if err != nil && test.err == nil { if err != nil && test.err == nil {
log.Fatalf("run : %s", err) log.Fatalf("run : %s", err)
} }
if isOk := assert.Equal(t, test.result, result); !isOk { if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.filter) t.Fatalf("test '%s' : NOK", test.filter)
} }
} }
} }
} }
func TestMatch(t *testing.T) { func TestMatch(t *testing.T) {
err := Init(nil) err := Init(nil)
require.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
glob string glob string
val string val string
@ -149,12 +151,15 @@ func TestMatch(t *testing.T) {
"pattern": test.glob, "pattern": test.glob,
"name": test.val, "name": test.val,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
if err != nil { if err != nil {
t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err) t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err)
} }
ret, err := expr.Run(vm, env) ret, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
if isOk := assert.Equal(t, test.ret, ret); !isOk { if isOk := assert.Equal(t, test.ret, ret); !isOk {
t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret) t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret)
} }
@ -194,10 +199,10 @@ func TestDistanceHelper(t *testing.T) {
} }
ret, err := expr.Run(vm, env) ret, err := expr.Run(vm, env)
if test.valid { if test.valid {
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.dist, ret) assert.Equal(t, test.dist, ret)
} else { } else {
assert.NotNil(t, err) require.Error(t, err)
} }
}) })
} }
@ -283,10 +288,12 @@ func TestRegexpInFile(t *testing.T) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
result, err := expr.Run(compiledFilter, map[string]interface{}{}) result, err := expr.Run(compiledFilter, map[string]interface{}{})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if isOk := assert.Equal(t, test.result, result); !isOk { if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
@ -335,28 +342,34 @@ func TestFileInit(t *testing.T) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if test.types == "string" {
switch test.types {
case "string":
if _, ok := dataFile[test.filename]; !ok { if _, ok := dataFile[test.filename]; !ok {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
if isOk := assert.Equal(t, test.result, len(dataFile[test.filename])); !isOk {
if isOk := assert.Len(t, dataFile[test.filename], test.result); !isOk {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
} else if test.types == "regex" { case "regex":
if _, ok := dataFileRegex[test.filename]; !ok { if _, ok := dataFileRegex[test.filename]; !ok {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
if isOk := assert.Equal(t, test.result, len(dataFileRegex[test.filename])); !isOk {
if isOk := assert.Len(t, dataFileRegex[test.filename], test.result); !isOk {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
} else { default:
if _, ok := dataFileRegex[test.filename]; ok { if _, ok := dataFileRegex[test.filename]; ok {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
if _, ok := dataFile[test.filename]; ok { if _, ok := dataFile[test.filename]; ok {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
} }
log.Printf("test '%s' : OK", test.name) log.Printf("test '%s' : OK", test.name)
} }
} }
@ -408,21 +421,23 @@ func TestFile(t *testing.T) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
result, err := expr.Run(compiledFilter, map[string]interface{}{}) result, err := expr.Run(compiledFilter, map[string]interface{}{})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if isOk := assert.Equal(t, test.result, result); !isOk { if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.name) t.Fatalf("test '%s' : NOK", test.name)
} }
log.Printf("test '%s' : OK", test.name)
log.Printf("test '%s' : OK", test.name)
} }
} }
func TestIpInRange(t *testing.T) { func TestIpInRange(t *testing.T) {
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
env map[string]interface{} env map[string]interface{}
@ -470,12 +485,11 @@ func TestIpInRange(t *testing.T) {
require.Equal(t, test.result, output) require.Equal(t, test.result, output)
log.Printf("test '%s' : OK", test.name) log.Printf("test '%s' : OK", test.name)
} }
} }
func TestIpToRange(t *testing.T) { func TestIpToRange(t *testing.T) {
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
env map[string]interface{} env map[string]interface{}
@ -543,13 +557,11 @@ func TestIpToRange(t *testing.T) {
require.Equal(t, test.result, output) require.Equal(t, test.result, output)
log.Printf("test '%s' : OK", test.name) log.Printf("test '%s' : OK", test.name)
} }
} }
func TestAtof(t *testing.T) { func TestAtof(t *testing.T) {
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
@ -593,13 +605,14 @@ func TestUpper(t *testing.T) {
} }
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...) vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
v, ok := out.(string) v, ok := out.(string)
if !ok { if !ok {
t.Fatalf("Upper() should return a string") t.Fatalf("Upper() should return a string")
@ -612,6 +625,7 @@ func TestUpper(t *testing.T) {
func TestTimeNow(t *testing.T) { func TestTimeNow(t *testing.T) {
now, _ := TimeNow() now, _ := TimeNow()
ti, err := time.Parse(time.RFC3339, now.(string)) ti, err := time.Parse(time.RFC3339, now.(string))
if err != nil { if err != nil {
t.Fatalf("Error parsing the return value of TimeNow: %s", err) t.Fatalf("Error parsing the return value of TimeNow: %s", err)
@ -620,6 +634,7 @@ func TestTimeNow(t *testing.T) {
if -1*time.Until(ti) > time.Second { if -1*time.Until(ti) > time.Second {
t.Fatalf("TimeNow func should return time.Now().UTC()") t.Fatalf("TimeNow func should return time.Now().UTC()")
} }
log.Printf("test 'TimeNow()' : OK") log.Printf("test 'TimeNow()' : OK")
} }
@ -894,15 +909,14 @@ func TestLower(t *testing.T) {
} }
func TestGetDecisionsCount(t *testing.T) { func TestGetDecisionsCount(t *testing.T) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
existingIP := "1.2.3.4" existingIP := "1.2.3.4"
unknownIP := "1.2.3.5" unknownIP := "1.2.3.5"
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP)
if err != nil { if err != nil {
t.Errorf("unable to convert '%s' to int: %s", existingIP, err) t.Errorf("unable to convert '%s' to int: %s", existingIP, err)
} }
// Add sample data to DB // Add sample data to DB
dbClient = getDBClient(t) dbClient = getDBClient(t)
@ -921,11 +935,11 @@ func TestGetDecisionsCount(t *testing.T) {
SaveX(context.Background()) SaveX(context.Background())
if decision == nil { if decision == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision")) require.Error(t, errors.Errorf("Failed to create sample decision"))
} }
err = Init(dbClient) err = Init(dbClient)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
@ -982,12 +996,10 @@ func TestGetDecisionsCount(t *testing.T) {
} }
} }
func TestGetDecisionsSinceCount(t *testing.T) { func TestGetDecisionsSinceCount(t *testing.T) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
existingIP := "1.2.3.4" existingIP := "1.2.3.4"
unknownIP := "1.2.3.5" unknownIP := "1.2.3.5"
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP)
if err != nil { if err != nil {
t.Errorf("unable to convert '%s' to int: %s", existingIP, err) t.Errorf("unable to convert '%s' to int: %s", existingIP, err)
} }
@ -1008,8 +1020,9 @@ func TestGetDecisionsSinceCount(t *testing.T) {
SetOrigin("CAPI"). SetOrigin("CAPI").
SaveX(context.Background()) SaveX(context.Background())
if decision == nil { if decision == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision")) require.Error(t, errors.Errorf("Failed to create sample decision"))
} }
decision2 := dbClient.Ent.Decision.Create(). decision2 := dbClient.Ent.Decision.Create().
SetCreatedAt(time.Now().AddDate(0, 0, -1)). SetCreatedAt(time.Now().AddDate(0, 0, -1)).
SetUntil(time.Now().AddDate(0, 0, -1)). SetUntil(time.Now().AddDate(0, 0, -1)).
@ -1024,12 +1037,13 @@ func TestGetDecisionsSinceCount(t *testing.T) {
SetValue(existingIP). SetValue(existingIP).
SetOrigin("CAPI"). SetOrigin("CAPI").
SaveX(context.Background()) SaveX(context.Background())
if decision2 == nil { if decision2 == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision")) require.Error(t, errors.Errorf("Failed to create sample decision"))
} }
err = Init(dbClient) err = Init(dbClient)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
@ -1152,6 +1166,7 @@ func TestIsIp(t *testing.T) {
if err := Init(nil); err != nil { if err := Init(nil); err != nil {
log.Fatal(err) log.Fatal(err)
} }
tests := []struct { tests := []struct {
name string name string
expr string expr string
@ -1235,17 +1250,18 @@ func TestIsIp(t *testing.T) {
expectedBuildErr: true, expectedBuildErr: true,
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
if tc.expectedBuildErr { if tc.expectedBuildErr {
assert.Error(t, err) require.Error(t, err)
return return
} }
assert.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
assert.NoError(t, err) require.NoError(t, err)
assert.IsType(t, tc.expected, output) assert.IsType(t, tc.expected, output)
assert.Equal(t, tc.expected, output.(bool)) assert.Equal(t, tc.expected, output.(bool))
}) })
@ -1255,6 +1271,7 @@ func TestIsIp(t *testing.T) {
func TestToString(t *testing.T) { func TestToString(t *testing.T) {
err := Init(nil) err := Init(nil)
require.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
value interface{} value interface{}
@ -1290,9 +1307,9 @@ func TestToString(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
assert.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
assert.NoError(t, err) require.NoError(t, err)
require.Equal(t, tc.expected, output) require.Equal(t, tc.expected, output)
}) })
} }
@ -1338,16 +1355,16 @@ func TestB64Decode(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
if tc.expectedBuildErr { if tc.expectedBuildErr {
assert.Error(t, err) require.Error(t, err)
return return
} }
assert.NoError(t, err) require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
if tc.expectedRuntimeErr { if tc.expectedRuntimeErr {
assert.Error(t, err) require.Error(t, err)
return return
} }
assert.NoError(t, err) require.NoError(t, err)
require.Equal(t, tc.expected, output) require.Equal(t, tc.expected, output)
}) })
} }
@ -1412,9 +1429,9 @@ func TestParseKv(t *testing.T) {
"out": outMap, "out": outMap,
} }
vm, err := expr.Compile(tc.expr, GetExprOptions(env)...) vm, err := expr.Compile(tc.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
_, err = expr.Run(vm, env) _, err = expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tc.expected, outMap["a"]) assert.Equal(t, tc.expected, outMap["a"])
}) })
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/antonmedv/expr" "github.com/antonmedv/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestJsonExtract(t *testing.T) { func TestJsonExtract(t *testing.T) {
@ -56,14 +57,14 @@ func TestJsonExtract(t *testing.T) {
"target": test.targetField, "target": test.targetField,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, out) assert.Equal(t, test.expectResult, out)
}) })
} }
} }
func TestJsonExtractUnescape(t *testing.T) { func TestJsonExtractUnescape(t *testing.T) {
if err := Init(nil); err != nil { if err := Init(nil); err != nil {
log.Fatal(err) log.Fatal(err)
@ -104,9 +105,9 @@ func TestJsonExtractUnescape(t *testing.T) {
"target": test.targetField, "target": test.targetField,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, out) assert.Equal(t, test.expectResult, out)
}) })
} }
@ -167,9 +168,9 @@ func TestJsonExtractSlice(t *testing.T) {
"target": test.targetField, "target": test.targetField,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, out) assert.Equal(t, test.expectResult, out)
}) })
} }
@ -223,9 +224,9 @@ func TestJsonExtractObject(t *testing.T) {
"target": test.targetField, "target": test.targetField,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, out) assert.Equal(t, test.expectResult, out)
}) })
} }
@ -233,7 +234,8 @@ func TestJsonExtractObject(t *testing.T) {
func TestToJson(t *testing.T) { func TestToJson(t *testing.T) {
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
obj interface{} obj interface{}
@ -298,9 +300,9 @@ func TestToJson(t *testing.T) {
"obj": test.obj, "obj": test.obj,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
out, err := expr.Run(vm, env) out, err := expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, out) assert.Equal(t, test.expectResult, out)
}) })
} }
@ -308,7 +310,8 @@ func TestToJson(t *testing.T) {
func TestUnmarshalJSON(t *testing.T) { func TestUnmarshalJSON(t *testing.T) {
err := Init(nil) err := Init(nil)
assert.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
json string json string
@ -361,11 +364,10 @@ func TestUnmarshalJSON(t *testing.T) {
"out": outMap, "out": outMap,
} }
vm, err := expr.Compile(test.expr, GetExprOptions(env)...) vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err) require.NoError(t, err)
_, err = expr.Run(vm, env) _, err = expr.Run(vm, env)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, test.expectResult, outMap["a"]) assert.Equal(t, test.expectResult, outMap["a"])
}) })
} }
} }

View file

@ -353,7 +353,7 @@ func TestUnitFound(t *testing.T) {
installed, err := env.UnitFound("crowdsec-setup-detect.service") installed, err := env.UnitFound("crowdsec-setup-detect.service")
require.NoError(err) require.NoError(err)
require.Equal(true, installed) require.True(installed)
} }
// TODO apply rules to filter a list of Service structs // TODO apply rules to filter a list of Service structs
@ -566,8 +566,8 @@ func TestDetectForcedUnit(t *testing.T) {
func TestDetectForcedProcess(t *testing.T) { func TestDetectForcedProcess(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
// while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning. // while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning.
t.Skip("skipping on windows")
} }
require := require.New(t) require := require.New(t)

View file

@ -73,7 +73,7 @@ func TestParseIPSources(t *testing.T) {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ips := tt.evt.ParseIPSources() ips := tt.evt.ParseIPSources()
assert.Equal(t, ips, tt.expected) assert.Equal(t, tt.expected, ips)
}) })
} }
} }