Bladeren bron

Refactor unit tests to reduce line count (#1264)

mmetc 3 jaren geleden
bovenliggende
commit
9bc7e6ffcf

+ 16 - 22
pkg/acquisition/acquisition_test.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/pkg/errors"
 	"github.com/prometheus/client_golang/prometheus"
@@ -184,10 +185,9 @@ wowo: ajsajasjas
 				t.Fatalf("%s : expected error '%s' in '%s'", test.TestName, test.ExpectedError, err.Error())
 			}
 			continue
-		} else {
-			if err != nil {
-				t.Fatalf("%s : unexpected error '%s'", test.TestName, err)
-			}
+		}
+		if err != nil {
+			t.Fatalf("%s : unexpected error '%s'", test.TestName, err)
 		}
 
 		switch test.TestName {
@@ -290,10 +290,9 @@ func TestLoadAcquisitionFromFile(t *testing.T) {
 				t.Fatalf("%s : expected error '%s' in '%s'", test.TestName, test.ExpectedError, err.Error())
 			}
 			continue
-		} else {
-			if err != nil {
-				t.Fatalf("%s : unexpected error '%s'", test.TestName, err)
-			}
+		}
+		if err != nil {
+			t.Fatalf("%s : unexpected error '%s'", test.TestName, err)
 		}
 		if len(dss) != test.ExpectedLen {
 			t.Fatalf("%s : expected %d datasources got %d", test.TestName, test.ExpectedLen, len(dss))
@@ -336,11 +335,13 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro
 func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error {
 	return fmt.Errorf("can't run in tail")
 }
-func (f *MockCat) CanRun() error                                   { return nil }
-func (f *MockCat) GetMetrics() []prometheus.Collector              { return nil }
-func (f *MockCat) GetAggregMetrics() []prometheus.Collector        { return nil }
-func (f *MockCat) Dump() interface{}                               { return f }
-func (f *MockCat) ConfigureByDSN(string, map[string]string, *log.Entry) error { return fmt.Errorf("not supported") }
+func (f *MockCat) CanRun() error                            { return nil }
+func (f *MockCat) GetMetrics() []prometheus.Collector       { return nil }
+func (f *MockCat) GetAggregMetrics() []prometheus.Collector { return nil }
+func (f *MockCat) Dump() interface{}                        { return f }
+func (f *MockCat) ConfigureByDSN(string, map[string]string, *log.Entry) error {
+	return fmt.Errorf("not supported")
+}
 
 //----
 
@@ -554,15 +555,8 @@ func TestConfigureByDSN(t *testing.T) {
 
 	for _, test := range tests {
 		srcs, err := LoadAcquisitionFromDSN(test.dsn, map[string]string{"type": "test_label"})
-		if err != nil && test.ExpectedError != "" {
-			if !strings.Contains(err.Error(), test.ExpectedError) {
-				t.Fatalf("expected '%s', got '%s'", test.ExpectedError, err.Error())
-			}
-		} else if err != nil && test.ExpectedError == "" {
-			t.Fatalf("got unexpected error '%s'", err.Error())
-		} else if err == nil && test.ExpectedError != "" {
-			t.Fatalf("expected error '%s' got none", test.ExpectedError)
-		}
+		cstest.AssertErrorContains(t, err, test.ExpectedError)
+
 		if len(srcs) != test.ExpectedResLen {
 			t.Fatalf("expected %d results, got %d", test.ExpectedResLen, len(srcs))
 		}

+ 6 - 27
pkg/acquisition/modules/docker/docker_test.go

@@ -9,6 +9,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	dockerTypes "github.com/docker/docker/api/types"
 	"github.com/docker/docker/client"
@@ -53,12 +54,7 @@ container_name:
 	for _, test := range tests {
 		f := DockerSource{}
 		err := f.Configure([]byte(test.config), subLogger)
-		if test.expectedErr != "" && err == nil {
-			t.Fatalf("Expected err %s but got nil !", test.expectedErr)
-		}
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
@@ -102,11 +98,7 @@ func TestConfigureDSN(t *testing.T) {
 	for _, test := range tests {
 		f := DockerSource{}
 		err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger)
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		} else {
-			assert.Equal(t, err, nil)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
@@ -196,14 +188,8 @@ container_name_regexp:
 			}
 		})
 		time.Sleep(10 * time.Second)
-		if ts.expectedErr == "" && err != nil {
-			t.Fatalf("Unexpected error : %s", err)
-		} else if ts.expectedErr != "" && err != nil {
-			assert.Contains(t, err.Error(), ts.expectedErr)
-			continue
-		} else if ts.expectedErr != "" && err == nil {
-			t.Fatalf("Expected error %s, but got nothing !", ts.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
+
 		if err := readerTomb.Wait(); err != nil {
 			t.Fatal(err)
 		}
@@ -311,15 +297,8 @@ func TestOneShot(t *testing.T) {
 		}
 		tomb := tomb.Tomb{}
 		err := dockerClient.OneShotAcquisition(out, &tomb)
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
 
-		if ts.expectedErr == "" && err != nil {
-			t.Fatalf("Unexpected error : %s", err)
-		} else if ts.expectedErr != "" && err != nil {
-			assert.Contains(t, err.Error(), ts.expectedErr)
-			continue
-		} else if ts.expectedErr != "" && err == nil {
-			t.Fatalf("Expected error %s, but got nothing !", ts.expectedErr)
-		}
 		// else we do the check before actualLines is incremented ...
 		time.Sleep(1 * time.Second)
 		if ts.expectedLines != 0 {

+ 52 - 58
pkg/acquisition/modules/file/file_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus/hooks/test"
@@ -70,87 +71,91 @@ func TestConfigureDSN(t *testing.T) {
 	for _, test := range tests {
 		f := FileSource{}
 		err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger)
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		} else {
-			assert.Equal(t, err, nil)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
 func TestOneShot(t *testing.T) {
 	tests := []struct {
-		config         string
-		expectedErr    string
-		expectedOutput string
-		expectedLines  int
-		logLevel       log.Level
-		setup          func()
-		afterConfigure func()
-		teardown       func()
+		config            string
+		expectedConfigErr string
+		expectedErr       string
+		expectedOutput    string
+		expectedLines     int
+		logLevel          log.Level
+		setup             func()
+		afterConfigure    func()
+		teardown          func()
 	}{
 		{
 			config: `
 mode: cat
 filename: /etc/shadow`,
-			expectedErr:    "failed opening /etc/shadow: open /etc/shadow: permission denied",
-			expectedOutput: "",
-			logLevel:       log.WarnLevel,
-			expectedLines:  0,
+			expectedConfigErr: "",
+			expectedErr:       "failed opening /etc/shadow: open /etc/shadow: permission denied",
+			expectedOutput:    "",
+			logLevel:          log.WarnLevel,
+			expectedLines:     0,
 		},
 		{
 			config: `
 mode: cat
 filename: /`,
-			expectedErr:    "",
-			expectedOutput: "/ is a directory, ignoring it",
-			logLevel:       log.WarnLevel,
-			expectedLines:  0,
+			expectedConfigErr: "",
+			expectedErr:       "",
+			expectedOutput:    "/ is a directory, ignoring it",
+			logLevel:          log.WarnLevel,
+			expectedLines:     0,
 		},
 		{
 			config: `
 mode: cat
 filename: "[*-.log"`,
-			expectedErr:    "Glob failure: syntax error in pattern",
-			expectedOutput: "",
-			logLevel:       log.WarnLevel,
-			expectedLines:  0,
+			expectedConfigErr: "Glob failure: syntax error in pattern",
+			expectedErr:       "",
+			expectedOutput:    "",
+			logLevel:          log.WarnLevel,
+			expectedLines:     0,
 		},
 		{
 			config: `
 mode: cat
 filename: /do/not/exist`,
-			expectedErr:    "",
-			expectedOutput: "No matching files for pattern /do/not/exist",
-			logLevel:       log.WarnLevel,
-			expectedLines:  0,
+			expectedConfigErr: "",
+			expectedErr:       "",
+			expectedOutput:    "No matching files for pattern /do/not/exist",
+			logLevel:          log.WarnLevel,
+			expectedLines:     0,
 		},
 		{
 			config: `
 mode: cat
 filename: test_files/test.log`,
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  5,
-			logLevel:       log.WarnLevel,
+			expectedConfigErr: "",
+			expectedErr:       "",
+			expectedOutput:    "",
+			expectedLines:     5,
+			logLevel:          log.WarnLevel,
 		},
 		{
 			config: `
 mode: cat
 filename: test_files/test.log.gz`,
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  5,
-			logLevel:       log.WarnLevel,
+			expectedConfigErr: "",
+			expectedErr:       "",
+			expectedOutput:    "",
+			expectedLines:     5,
+			logLevel:          log.WarnLevel,
 		},
 		{
 			config: `
 mode: cat
 filename: test_files/bad.gz`,
-			expectedErr:    "failed to read gz test_files/bad.gz: unexpected EOF",
-			expectedOutput: "",
-			expectedLines:  0,
-			logLevel:       log.WarnLevel,
+			expectedConfigErr: "",
+			expectedErr:       "failed to read gz test_files/bad.gz: unexpected EOF",
+			expectedOutput:    "",
+			expectedLines:     0,
+			logLevel:          log.WarnLevel,
 		},
 		{
 			config: `
@@ -179,12 +184,11 @@ filename: test_files/test_delete.log`,
 			ts.setup()
 		}
 		err := f.Configure([]byte(ts.config), subLogger)
-		if err != nil && ts.expectedErr != "" {
-			assert.Contains(t, err.Error(), ts.expectedErr)
+		cstest.AssertErrorContains(t, err, ts.expectedConfigErr)
+		if err != nil {
 			continue
-		} else if err != nil && ts.expectedErr == "" {
-			t.Fatalf("Unexpected error : %s", err)
 		}
+
 		if ts.afterConfigure != nil {
 			ts.afterConfigure()
 		}
@@ -203,15 +207,11 @@ filename: test_files/test_delete.log`,
 			}()
 		}
 		err = f.OneShotAcquisition(out, &tomb)
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
+
 		if ts.expectedLines != 0 {
 			assert.Equal(t, actualLines, ts.expectedLines)
 		}
-		if ts.expectedErr != "" {
-			if err == nil {
-				t.Fatalf("Expected error but got nothing ! %+v", ts)
-			}
-			assert.Contains(t, err.Error(), ts.expectedErr)
-		}
 		if ts.expectedOutput != "" {
 			assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
 			hook.Reset()
@@ -359,13 +359,7 @@ force_inotify: true`,
 			}()
 		}
 		err = f.StreamingAcquisition(out, &tomb)
-
-		if ts.expectedErr != "" {
-			if err == nil {
-				t.Fatalf("Expected error but got nothing ! %+v", ts)
-			}
-			assert.Contains(t, err.Error(), ts.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
 
 		if ts.expectedLines != 0 {
 			fd, err := os.Create("test_files/stream.log")

+ 8 - 23
pkg/acquisition/modules/journalctl/journalctl_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus/hooks/test"
@@ -44,12 +45,7 @@ journalctl_filter:
 	for _, test := range tests {
 		f := JournalCtlSource{}
 		err := f.Configure([]byte(test.config), subLogger)
-		if test.expectedErr != "" && err == nil {
-			t.Fatalf("Expected err %s but got nil !", test.expectedErr)
-		}
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
@@ -93,11 +89,7 @@ func TestConfigureDSN(t *testing.T) {
 	for _, test := range tests {
 		f := JournalCtlSource{}
 		err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger)
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		} else {
-			assert.Equal(t, err, nil)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
@@ -170,14 +162,11 @@ journalctl_filter:
 		}
 
 		err = j.OneShotAcquisition(out, &tomb)
-		if ts.expectedErr == "" && err != nil {
-			t.Fatalf("Unexpected error : %s", err)
-		} else if ts.expectedErr != "" && err != nil {
-			assert.Contains(t, err.Error(), ts.expectedErr)
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
+		if err != nil {
 			continue
-		} else if ts.expectedErr != "" && err == nil {
-			t.Fatalf("Expected error %s, but got nothing !", ts.expectedErr)
 		}
+
 		if ts.expectedLines != 0 {
 			assert.Equal(t, ts.expectedLines, actualLines)
 		}
@@ -250,13 +239,9 @@ journalctl_filter:
 		}
 
 		err = j.StreamingAcquisition(out, &tomb)
-		if ts.expectedErr == "" && err != nil {
-			t.Fatalf("Unexpected error : %s", err)
-		} else if ts.expectedErr != "" && err != nil {
-			assert.Contains(t, err.Error(), ts.expectedErr)
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
+		if err != nil {
 			continue
-		} else if ts.expectedErr != "" && err == nil {
-			t.Fatalf("Expected error %s, but got nothing !", ts.expectedErr)
 		}
 
 		if ts.expectedLines != 0 {

+ 2 - 6
pkg/acquisition/modules/kinesis/kinesis_test.go

@@ -14,6 +14,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/kinesis"
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
@@ -138,12 +139,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`,
 	for _, test := range tests {
 		f := KinesisSource{}
 		err := f.Configure([]byte(test.config), subLogger)
-		if test.expectedErr != "" && err == nil {
-			t.Fatalf("Expected err %s but got nil !", test.expectedErr)
-		}
-		if test.expectedErr != "" {
-			assert.Contains(t, err.Error(), test.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 

+ 5 - 12
pkg/acquisition/modules/syslog/syslog_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"gopkg.in/tomb.v2"
@@ -54,12 +55,7 @@ listen_addr: 10.0.0`,
 	for _, test := range tests {
 		s := SyslogSource{}
 		err := s.Configure([]byte(test.config), subLogger)
-		if test.expectedErr != "" {
-			if err == nil {
-				t.Fatalf("Expected error but got nothing : %+v", test)
-			}
-			assert.Contains(t, err.Error(), test.expectedErr)
-		}
+		cstest.AssertErrorContains(t, err, test.expectedErr)
 	}
 }
 
@@ -123,14 +119,11 @@ listen_addr: 127.0.0.1`,
 		tomb := tomb.Tomb{}
 		out := make(chan types.Event)
 		err := s.StreamingAcquisition(out, &tomb)
-		if ts.expectedErr != "" && err == nil {
-			t.Fatalf("expected error but got nothing : %+v", ts)
-		} else if ts.expectedErr == "" && err != nil {
-			t.Fatalf("unexpected error : %s", err)
-		} else if ts.expectedErr != "" && err != nil {
-			assert.Contains(t, err.Error(), ts.expectedErr)
+		cstest.AssertErrorContains(t, err, ts.expectedErr)
+		if err != nil {
 			continue
 		}
+
 		actualLines := 0
 		go writeToSyslog(ts.logs)
 	READLOOP:

+ 17 - 0
pkg/cstest/utils.go

@@ -4,6 +4,9 @@ import (
 	"fmt"
 	"io/ioutil"
 	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func Copy(sourceFile string, destinationFile string) error {
@@ -79,3 +82,17 @@ func CopyDir(src string, dest string) error {
 
 	return nil
 }
+
+func AssertErrorContains(t *testing.T, err error, expectedErr string) {
+	if expectedErr == "" {
+		if err != nil {
+			t.Fatalf("Unexpected error: %s", err)
+		}
+		assert.Equal(t, err, nil)
+		return
+	}
+	if err == nil {
+		t.Fatalf("Expected '%s', got nil", expectedErr)
+	}
+	assert.Contains(t, err.Error(), expectedErr)
+}