Browse Source

cleanup + fix flaky tests in file_test.go, apic_test.go (#1773)

mmetc 2 years ago
parent
commit
edced6818a
5 changed files with 394 additions and 392 deletions
  1. 231 217
      pkg/acquisition/modules/file/file_test.go
  2. 19 24
      pkg/apiserver/apic.go
  3. 129 147
      pkg/apiserver/apic_test.go
  4. 1 4
      pkg/apiserver/apiserver.go
  5. 14 0
      pkg/cstest/utils.go

+ 231 - 217
pkg/acquisition/modules/file/file_test.go

@@ -1,4 +1,4 @@
-package fileacquisition
+package fileacquisition_test
 
 import (
 	"fmt"
@@ -7,32 +7,39 @@ import (
 	"testing"
 	"time"
 
+	fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file"
 	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus/hooks/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"gopkg.in/tomb.v2"
 )
 
 func TestBadConfiguration(t *testing.T) {
 	tests := []struct {
+		name        string
 		config      string
 		expectedErr string
 	}{
 		{
-			config:      `foobar: asd.log`,
+			name:        "extra configuration key",
+			config:      "foobar: asd.log",
 			expectedErr: "line 1: field foobar not found in type fileacquisition.FileConfiguration",
 		},
 		{
-			config:      `mode: tail`,
+			name:        "missing filenames",
+			config:      "mode: tail",
 			expectedErr: "no filename or filenames configuration provided",
 		},
 		{
+			name:        "glob syntax error",
 			config:      `filename: "[asd-.log"`,
 			expectedErr: "Glob failure: syntax error in pattern",
 		},
 		{
+			name: "bad exclude regexp",
 			config: `filenames: ["asd.log"]
 exclude_regexps: ["as[a-$d"]`,
 			expectedErr: "Could not compile regexp as",
@@ -42,20 +49,24 @@ exclude_regexps: ["as[a-$d"]`,
 	subLogger := log.WithFields(log.Fields{
 		"type": "file",
 	})
-	for _, test := range tests {
-		f := FileSource{}
-		err := f.Configure([]byte(test.config), subLogger)
-		assert.Contains(t, err.Error(), test.expectedErr)
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			f := fileacquisition.FileSource{}
+			err := f.Configure([]byte(tc.config), subLogger)
+			cstest.RequireErrorContains(t, err, tc.expectedErr)
+		})
 	}
 }
 
 func TestConfigureDSN(t *testing.T) {
-	var file string
-	if runtime.GOOS != "windows" {
-		file = "/etc/passwd"
-	} else {
-		file = "C:\\Windows\\System32\\drivers\\etc\\hosts"
+	file := "/etc/passwd"
+
+	if runtime.GOOS == "windows" {
+		file = `C:\Windows\System32\drivers\etc\hosts`
 	}
+
 	tests := []struct {
 		dsn         string
 		expectedErr string
@@ -69,37 +80,41 @@ func TestConfigureDSN(t *testing.T) {
 			expectedErr: "empty file:// DSN",
 		},
 		{
-			dsn:         fmt.Sprintf("file://%s?log_level=warn", file),
-			expectedErr: "",
+			dsn: fmt.Sprintf("file://%s?log_level=warn", file),
 		},
 		{
 			dsn:         fmt.Sprintf("file://%s?log_level=foobar", file),
 			expectedErr: "unknown level foobar: not a valid logrus Level:",
 		},
 	}
+
 	subLogger := log.WithFields(log.Fields{
 		"type": "file",
 	})
-	for _, test := range tests {
-		f := FileSource{}
-		err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger)
-		cstest.AssertErrorContains(t, err, test.expectedErr)
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.dsn, func(t *testing.T) {
+			f := fileacquisition.FileSource{}
+			err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger)
+			cstest.RequireErrorContains(t, err, tc.expectedErr)
+		})
 	}
 }
 
 func TestOneShot(t *testing.T) {
-	var permDeniedFile string
-	var permDeniedError string
-	if runtime.GOOS != "windows" {
-		permDeniedFile = "/etc/shadow"
-		permDeniedError = "failed opening /etc/shadow: open /etc/shadow: permission denied"
-	} else {
-		//Technically, this is not a permission denied error, but we just want to test what happens
-		//if we do not have access to the file
-		permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
-		permDeniedError = "failed opening C:\\Windows\\System32\\config\\SAM: open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process."
+	permDeniedFile := "/etc/shadow"
+	permDeniedError := "failed opening /etc/shadow: open /etc/shadow: permission denied"
+
+	if runtime.GOOS == "windows" {
+		// Technically, this is not a permission denied error, but we just want to test what happens
+		// if we do not have access to the file
+		permDeniedFile = `C:\Windows\System32\config\SAM`
+		permDeniedError = `failed opening C:\Windows\System32\config\SAM: open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process.`
 	}
+
 	tests := []struct {
+		name              string
 		config            string
 		expectedConfigErr string
 		expectedErr       string
@@ -111,76 +126,68 @@ func TestOneShot(t *testing.T) {
 		teardown          func()
 	}{
 		{
+			name: "permission denied",
 			config: fmt.Sprintf(`
 mode: cat
 filename: %s`, permDeniedFile),
-			expectedConfigErr: "",
-			expectedErr:       permDeniedError,
-			expectedOutput:    "",
-			logLevel:          log.WarnLevel,
-			expectedLines:     0,
+			expectedErr:   permDeniedError,
+			logLevel:      log.WarnLevel,
+			expectedLines: 0,
 		},
 		{
+			name: "ignored directory",
 			config: `
 mode: cat
 filename: /`,
-			expectedConfigErr: "",
-			expectedErr:       "",
-			expectedOutput:    "/ is a directory, ignoring it",
-			logLevel:          log.WarnLevel,
-			expectedLines:     0,
+			expectedOutput: "/ is a directory, ignoring it",
+			logLevel:       log.WarnLevel,
+			expectedLines:  0,
 		},
 		{
+			name: "glob syntax error",
 			config: `
 mode: cat
 filename: "[*-.log"`,
 			expectedConfigErr: "Glob failure: syntax error in pattern",
-			expectedErr:       "",
-			expectedOutput:    "",
 			logLevel:          log.WarnLevel,
 			expectedLines:     0,
 		},
 		{
+			name: "no matching files",
 			config: `
 mode: cat
 filename: /do/not/exist`,
-			expectedConfigErr: "",
-			expectedErr:       "",
-			expectedOutput:    "No matching files for pattern /do/not/exist",
-			logLevel:          log.WarnLevel,
-			expectedLines:     0,
+			expectedOutput: "No matching files for pattern /do/not/exist",
+			logLevel:       log.WarnLevel,
+			expectedLines:  0,
 		},
 		{
+			name: "test.log",
 			config: `
 mode: cat
 filename: test_files/test.log`,
-			expectedConfigErr: "",
-			expectedErr:       "",
-			expectedOutput:    "",
-			expectedLines:     5,
-			logLevel:          log.WarnLevel,
+			expectedLines: 5,
+			logLevel:      log.WarnLevel,
 		},
 		{
+			name: "test.log.gz",
 			config: `
 mode: cat
 filename: test_files/test.log.gz`,
-			expectedConfigErr: "",
-			expectedErr:       "",
-			expectedOutput:    "",
-			expectedLines:     5,
-			logLevel:          log.WarnLevel,
+			expectedLines: 5,
+			logLevel:      log.WarnLevel,
 		},
 		{
+			name: "unexpected end of gzip stream",
 			config: `
 mode: cat
 filename: test_files/bad.gz`,
-			expectedConfigErr: "",
-			expectedErr:       "failed to read gz test_files/bad.gz: unexpected EOF",
-			expectedOutput:    "",
-			expectedLines:     0,
-			logLevel:          log.WarnLevel,
+			expectedErr:   "failed to read gz test_files/bad.gz: unexpected EOF",
+			expectedLines: 0,
+			logLevel:      log.WarnLevel,
 		},
 		{
+			name: "deleted file",
 			config: `
 mode: cat
 filename: test_files/test_delete.log`,
@@ -195,77 +202,84 @@ filename: test_files/test_delete.log`,
 		},
 	}
 
-	for _, ts := range tests {
-		logger, hook := test.NewNullLogger()
-		logger.SetLevel(ts.logLevel)
-		subLogger := logger.WithFields(log.Fields{
-			"type": "file",
-		})
-		tomb := tomb.Tomb{}
-		out := make(chan types.Event)
-		f := FileSource{}
-		if ts.setup != nil {
-			ts.setup()
-		}
-		err := f.Configure([]byte(ts.config), subLogger)
-		cstest.AssertErrorContains(t, err, ts.expectedConfigErr)
-		if err != nil {
-			continue
-		}
-
-		if ts.afterConfigure != nil {
-			ts.afterConfigure()
-		}
-		actualLines := 0
-		if ts.expectedLines != 0 {
-			go func() {
-			READLOOP:
-				for {
-					select {
-					case <-out:
-						actualLines++
-					case <-time.After(1 * time.Second):
-						break READLOOP
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			logger, hook := test.NewNullLogger()
+			logger.SetLevel(tc.logLevel)
+
+			subLogger := logger.WithFields(log.Fields{
+				"type": "file",
+			})
+
+			tomb := tomb.Tomb{}
+			out := make(chan types.Event)
+			f := fileacquisition.FileSource{}
+
+			if tc.setup != nil {
+				tc.setup()
+			}
+
+			err := f.Configure([]byte(tc.config), subLogger)
+			cstest.RequireErrorContains(t, err, tc.expectedConfigErr)
+			if tc.expectedConfigErr != "" {
+				return
+			}
+
+			if tc.afterConfigure != nil {
+				tc.afterConfigure()
+			}
+
+			actualLines := 0
+			if tc.expectedLines != 0 {
+				go func() {
+					for {
+						select {
+						case <-out:
+							actualLines++
+						case <-time.After(2 * time.Second):
+							return
+						}
 					}
-				}
-			}()
-		}
-		err = f.OneShotAcquisition(out, &tomb)
-		cstest.AssertErrorContains(t, err, ts.expectedErr)
-
-		if ts.expectedLines != 0 {
-			assert.Equal(t, actualLines, ts.expectedLines)
-		}
-		if ts.expectedOutput != "" {
-			assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
-			hook.Reset()
-		}
-		if ts.teardown != nil {
-			ts.teardown()
-		}
+				}()
+			}
+
+			err = f.OneShotAcquisition(out, &tomb)
+			cstest.RequireErrorContains(t, err, tc.expectedErr)
+
+			if tc.expectedLines != 0 {
+				assert.Equal(t, tc.expectedLines, actualLines)
+			}
+
+			if tc.expectedOutput != "" {
+				assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
+				hook.Reset()
+			}
+			if tc.teardown != nil {
+				tc.teardown()
+			}
+		})
 	}
 }
 
 func TestLiveAcquisition(t *testing.T) {
-	var permDeniedFile string
-	var permDeniedError string
-	var testPattern string
-	if runtime.GOOS != "windows" {
-		permDeniedFile = "/etc/shadow"
-		permDeniedError = "unable to read /etc/shadow : open /etc/shadow: permission denied"
-		testPattern = "test_files/*.log"
-	} else {
-		//Technically, this is not a permission denied error, but we just want to test what happens
-		//if we do not have access to the file
-		permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
-		permDeniedError = "unable to read C:\\Windows\\System32\\config\\SAM : open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process"
-		testPattern = "test_files\\\\*.log" // the \ must be escaped twice: once for the string, once for the yaml config
+	permDeniedFile := "/etc/shadow"
+	permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied"
+	testPattern := "test_files/*.log"
+
+	if runtime.GOOS == "windows" {
+		// Technically, this is not a permission denied error, but we just want to test what happens
+		// if we do not have access to the file
+		permDeniedFile = `C:\Windows\System32\config\SAM`
+		permDeniedError = `unable to read C:\Windows\System32\config\SAM : open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process`
+		testPattern = `test_files\\*.log` // the \ must be escaped for the yaml config
 	}
+
 	tests := []struct {
+		name           string
 		config         string
 		expectedErr    string
 		expectedOutput string
-		name           string
 		expectedLines  int
 		logLevel       log.Level
 		setup          func()
@@ -276,7 +290,6 @@ func TestLiveAcquisition(t *testing.T) {
 			config: fmt.Sprintf(`
 mode: tail
 filename: %s`, permDeniedFile),
-			expectedErr:    "",
 			expectedOutput: permDeniedError,
 			logLevel:       log.InfoLevel,
 			expectedLines:  0,
@@ -286,7 +299,6 @@ filename: %s`, permDeniedFile),
 			config: `
 mode: tail
 filename: /`,
-			expectedErr:    "",
 			expectedOutput: "/ is a directory, ignoring it",
 			logLevel:       log.WarnLevel,
 			expectedLines:  0,
@@ -296,7 +308,6 @@ filename: /`,
 			config: `
 mode: tail
 filename: /do/not/exist`,
-			expectedErr:    "",
 			expectedOutput: "No matching files for pattern /do/not/exist",
 			logLevel:       log.WarnLevel,
 			expectedLines:  0,
@@ -308,11 +319,9 @@ mode: tail
 filenames:
  - %s
 force_inotify: true`, testPattern),
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  5,
-			logLevel:       log.DebugLevel,
-			name:           "basicGlob",
+			expectedLines: 5,
+			logLevel:      log.DebugLevel,
+			name:          "basicGlob",
 		},
 		{
 			config: fmt.Sprintf(`
@@ -320,11 +329,9 @@ mode: tail
 filenames:
  - %s
 force_inotify: true`, testPattern),
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  0,
-			logLevel:       log.DebugLevel,
-			name:           "GlobInotify",
+			expectedLines: 0,
+			logLevel:      log.DebugLevel,
+			name:          "GlobInotify",
 			afterConfigure: func() {
 				f, _ := os.Create("test_files/a.log")
 				f.Close()
@@ -338,19 +345,17 @@ mode: tail
 filenames:
  - %s
 force_inotify: true`, testPattern),
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  5,
-			logLevel:       log.DebugLevel,
-			name:           "GlobInotifyChmod",
+			expectedLines: 5,
+			logLevel:      log.DebugLevel,
+			name:          "GlobInotifyChmod",
 			afterConfigure: func() {
 				f, _ := os.Create("test_files/a.log")
 				f.Close()
 				time.Sleep(1 * time.Second)
-				os.Chmod("test_files/a.log", 0000)
+				os.Chmod("test_files/a.log", 0o000)
 			},
 			teardown: func() {
-				os.Chmod("test_files/a.log", 0644)
+				os.Chmod("test_files/a.log", 0o644)
 				os.Remove("test_files/a.log")
 			},
 		},
@@ -360,13 +365,11 @@ mode: tail
 filenames:
  - %s
 force_inotify: true`, testPattern),
-			expectedErr:    "",
-			expectedOutput: "",
-			expectedLines:  5,
-			logLevel:       log.DebugLevel,
-			name:           "InotifyMkDir",
+			expectedLines: 5,
+			logLevel:      log.DebugLevel,
+			name:          "InotifyMkDir",
 			afterConfigure: func() {
-				os.Mkdir("test_files/pouet/", 0700)
+				os.Mkdir("test_files/pouet/", 0o700)
 			},
 			teardown: func() {
 				os.Remove("test_files/pouet/")
@@ -374,101 +377,112 @@ force_inotify: true`, testPattern),
 		},
 	}
 
-	for _, ts := range tests {
-		t.Logf("test: %s", ts.name)
-		logger, hook := test.NewNullLogger()
-		logger.SetLevel(ts.logLevel)
-		subLogger := logger.WithFields(log.Fields{
-			"type": "file",
-		})
-		tomb := tomb.Tomb{}
-		out := make(chan types.Event)
-		f := FileSource{}
-		if ts.setup != nil {
-			ts.setup()
-		}
-		err := f.Configure([]byte(ts.config), subLogger)
-		if err != nil {
-			t.Fatalf("Unexpected error : %s", err)
-		}
-		if ts.afterConfigure != nil {
-			ts.afterConfigure()
-		}
-		actualLines := 0
-		if ts.expectedLines != 0 {
-			go func() {
-			READLOOP:
-				for {
-					select {
-					case <-out:
-						actualLines++
-					case <-time.After(2 * time.Second):
-						break READLOOP
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			logger, hook := test.NewNullLogger()
+			logger.SetLevel(tc.logLevel)
+
+			subLogger := logger.WithFields(log.Fields{
+				"type": "file",
+			})
+
+			tomb := tomb.Tomb{}
+			out := make(chan types.Event)
+
+			f := fileacquisition.FileSource{}
+
+			if tc.setup != nil {
+				tc.setup()
+			}
+
+			err := f.Configure([]byte(tc.config), subLogger)
+			require.NoError(t, err)
+
+			if tc.afterConfigure != nil {
+				tc.afterConfigure()
+			}
+
+			actualLines := 0
+			if tc.expectedLines != 0 {
+				go func() {
+					for {
+						select {
+						case <-out:
+							actualLines++
+						case <-time.After(2 * time.Second):
+							return
+						}
 					}
-				}
-			}()
-		}
-		err = f.StreamingAcquisition(out, &tomb)
-		cstest.AssertErrorContains(t, err, ts.expectedErr)
-
-		if ts.expectedLines != 0 {
-			fd, err := os.Create("test_files/stream.log")
-			if err != nil {
-				t.Fatalf("could not create test file : %s", err)
+				}()
 			}
-			for i := 0; i < 5; i++ {
-				_, err = fd.WriteString(fmt.Sprintf("%d\n", i))
+
+			err = f.StreamingAcquisition(out, &tomb)
+			cstest.RequireErrorContains(t, err, tc.expectedErr)
+
+			if tc.expectedLines != 0 {
+				fd, err := os.Create("test_files/stream.log")
 				if err != nil {
-					t.Fatalf("could not write test file : %s", err)
-					os.Remove("test_files/stream.log")
+					t.Fatalf("could not create test file : %s", err)
+				}
+
+				for i := 0; i < 5; i++ {
+					_, err = fmt.Fprintf(fd, "%d\n", i)
+					if err != nil {
+						t.Fatalf("could not write test file : %s", err)
+						os.Remove("test_files/stream.log")
+					}
 				}
+
+				fd.Close()
+				// we sleep to make sure we detect the new file
+				time.Sleep(1 * time.Second)
+				os.Remove("test_files/stream.log")
+				assert.Equal(t, tc.expectedLines, actualLines)
 			}
-			fd.Close()
-			//we sleep to make sure we detect the new file
-			time.Sleep(1 * time.Second)
-			os.Remove("test_files/stream.log")
-			assert.Equal(t, ts.expectedLines, actualLines)
-		}
-
-		if ts.expectedOutput != "" {
-			if hook.LastEntry() == nil {
-				t.Fatalf("expected output %s, but got nothing", ts.expectedOutput)
+
+			if tc.expectedOutput != "" {
+				if hook.LastEntry() == nil {
+					t.Fatalf("expected output %s, but got nothing", tc.expectedOutput)
+				}
+
+				assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
+				hook.Reset()
 			}
-			assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
-			hook.Reset()
-		}
 
-		if ts.teardown != nil {
-			ts.teardown()
-		}
+			if tc.teardown != nil {
+				tc.teardown()
+			}
 
-		tomb.Kill(nil)
+			tomb.Kill(nil)
+		})
 	}
 }
 
 func TestExclusion(t *testing.T) {
-
 	config := `filenames: ["test_files/*.log*"]
 exclude_regexps: ["\\.gz$"]`
 	logger, hook := test.NewNullLogger()
-	//logger.SetLevel(ts.logLevel)
+	// logger.SetLevel(ts.logLevel)
 	subLogger := logger.WithFields(log.Fields{
 		"type": "file",
 	})
-	f := FileSource{}
-	err := f.Configure([]byte(config), subLogger)
-	if err != nil {
+
+	f := fileacquisition.FileSource{}
+	if err := f.Configure([]byte(config), subLogger); err != nil {
 		subLogger.Fatalf("unexpected error: %s", err)
 	}
-	var expectedLogOutput string
+
+	expectedLogOutput := "Skipping file test_files/test.log.gz as it matches exclude pattern"
+
 	if runtime.GOOS == "windows" {
-		expectedLogOutput = "Skipping file test_files\\test.log.gz as it matches exclude pattern \\.gz"
-	} else {
-		expectedLogOutput = "Skipping file test_files/test.log.gz as it matches exclude pattern"
+		expectedLogOutput = `Skipping file test_files\test.log.gz as it matches exclude pattern \.gz`
 	}
+
 	if hook.LastEntry() == nil {
 		t.Fatalf("expected output %s, but got nothing", expectedLogOutput)
 	}
+
 	assert.Contains(t, hook.LastEntry().Message, expectedLogOutput)
 	hook.Reset()
 }

+ 19 - 24
pkg/apiserver/apic.go

@@ -400,7 +400,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
 	return alerts
 }
 
-//we receive only one list of decisions, that we need to break-up :
+// we receive only one list of decisions, that we need to break-up :
 // one alert for "community blocklist"
 // one alert per list we're subscribed to
 func (a *apic) PullTop() error {
@@ -432,7 +432,7 @@ func (a *apic) PullTop() error {
 		return nil
 	}
 
-	//we receive only one list of decisions, that we need to break-up :
+	// we receive only one list of decisions, that we need to break-up :
 	// one alert for "community blocklist"
 	// one alert per list we're subscribed to
 	alertsFromCapi := createAlertsForDecisions(data.New)
@@ -541,37 +541,32 @@ func (a *apic) GetMetrics() (*models.Metrics, error) {
 	return metric, nil
 }
 
-func (a *apic) SendMetrics() error {
+func (a *apic) SendMetrics(stop chan (bool)) {
 	defer types.CatchPanic("lapi/metricsToAPIC")
 
-	metrics, err := a.GetMetrics()
-	if err != nil {
-		log.Errorf("unable to get metrics (%s), will retry", err)
-	}
-	_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
-	if err != nil {
-		log.Errorf("unable to send metrics (%s), will retry", err)
-	}
-	log.Infof("capi metrics: metrics sent successfully")
-	log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", MetricsInterval)
+	log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", a.metricsInterval)
 	ticker := time.NewTicker(a.metricsInterval)
 	for {
+		metrics, err := a.GetMetrics()
+		if err != nil {
+			log.Errorf("unable to get metrics (%s), will retry", err)
+		}
+		_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
+		if err != nil {
+			log.Errorf("capi metrics: failed: %s", err)
+		} else {
+			log.Infof("capi metrics: metrics sent successfully")
+		}
+
 		select {
+		case <-stop:
+			return
 		case <-ticker.C:
-			metrics, err := a.GetMetrics()
-			if err != nil {
-				log.Errorf("unable to get metrics (%s), will retry", err)
-			}
-			_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
-			if err != nil {
-				log.Errorf("capi metrics: failed: %s", err)
-			} else {
-				log.Infof("capi metrics: metrics sent successfully")
-			}
+			continue
 		case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
 			a.pullTomb.Kill(nil)
 			a.pushTomb.Kill(nil)
-			return nil
+			return
 		}
 	}
 }

+ 129 - 147
pkg/apiserver/apic_test.go

@@ -8,13 +8,13 @@ import (
 	"net/url"
 	"os"
 	"reflect"
-	"sort"
 	"sync"
 	"testing"
 	"time"
 
 	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
+	"github.com/crowdsecurity/crowdsec/pkg/cstest"
 	"github.com/crowdsecurity/crowdsec/pkg/cwversion"
 	"github.com/crowdsecurity/crowdsec/pkg/database"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@@ -24,23 +24,20 @@ import (
 	"github.com/jarcoal/httpmock"
 	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"gopkg.in/tomb.v2"
 )
 
 func getDBClient(t *testing.T) *database.Client {
 	t.Helper()
 	dbPath, err := os.CreateTemp("", "*sqlite")
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
 		Type:   "sqlite",
 		DbName: "crowdsec",
 		DbPath: dbPath.Name(),
 	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	return dbClient
 }
 
@@ -98,11 +95,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
 
 func TestAPICCAPIPullIsOld(t *testing.T) {
 	api := getAPIC(t)
-	isOld, err := api.CAPIPullIsOld()
-	if err != nil {
-		t.Fatal(err)
-	}
 
+	isOld, err := api.CAPIPullIsOld()
+	require.NoError(t, err)
 	assert.True(t, isOld)
 
 	decision := api.dbClient.Ent.Decision.Create().
@@ -123,16 +118,13 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
 		SaveX(context.Background())
 
 	isOld, err = api.CAPIPullIsOld()
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	assert.False(t, isOld)
 }
 
 func TestAPICFetchScenariosListFromDB(t *testing.T) {
-	api := getAPIC(t)
-	testCases := []struct {
+	tests := []struct {
 		name                    string
 		machineIDsWithScenarios map[string]string
 		expectedScenarios       []string
@@ -154,8 +146,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
 		},
 	}
 
-	for _, tc := range testCases {
+	for _, tc := range tests {
+		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
+			api := getAPIC(t)
 			for machineID, scenarios := range tc.machineIDsWithScenarios {
 				api.dbClient.Ent.Machine.Create().
 					SetMachineId(machineID).
@@ -164,17 +158,14 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
 					SetScenarios(scenarios).
 					ExecX(context.Background())
 			}
+
 			scenarios, err := api.FetchScenariosListFromDB()
 			for machineID := range tc.machineIDsWithScenarios {
 				api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
 			}
-			if err != nil {
-				t.Fatal(err)
-			} else {
-				sort.Strings(scenarios)
-				sort.Strings(tc.expectedScenarios)
-				assert.Equal(t, scenarios, tc.expectedScenarios)
-			}
+			require.NoError(t, err)
+
+			assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
 		})
 
 	}
@@ -196,11 +187,10 @@ func TestNewAPIC(t *testing.T) {
 		consoleConfig *csconfig.ConsoleConfig
 	}
 	tests := []struct {
-		name          string
-		args          args
-		wantErr       bool
-		errorContains string
-		action        func()
+		name        string
+		args        args
+		expectedErr string
+		action      func()
 	}{
 		{
 			name:   "simple",
@@ -217,20 +207,16 @@ func TestNewAPIC(t *testing.T) {
 				dbClient:      getDBClient(t),
 				consoleConfig: LoadTestConfig().API.Server.ConsoleConfig,
 			},
-			wantErr:       true,
-			errorContains: "first path segment in URL cannot contain colon",
+			expectedErr: "first path segment in URL cannot contain colon",
 		},
 	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
 			setConfig()
-			tt.action()
-			_, err := NewAPIC(testConfig, tt.args.dbClient, tt.args.consoleConfig)
-			if tt.wantErr {
-				assert.ErrorContains(t, err, tt.errorContains)
-			} else {
-				assert.NoError(t, err)
-			}
+			tc.action()
+			_, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig)
+			cstest.RequireErrorContains(t, err, tc.expectedErr)
 		})
 	}
 }
@@ -268,17 +254,16 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
 	}}, deleteCounters)
 
 	assert.NoError(t, err)
-	assert.Equal(t, nbDeleted, 2)
-	assert.Equal(t, deleteCounters[SCOPE_CAPI]["all"], 2)
+	assert.Equal(t, 2, nbDeleted)
+	assert.Equal(t, 2, deleteCounters[SCOPE_CAPI]["all"])
 }
 
 func TestAPICGetMetrics(t *testing.T) {
-	api := getAPIC(t)
-	cleanUp := func() {
+	cleanUp := func(api *apic) {
 		api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
 		api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
 	}
-	testCases := []struct {
+	tests := []struct {
 		name           string
 		machineIDs     []string
 		bouncers       []string
@@ -322,11 +307,13 @@ func TestAPICGetMetrics(t *testing.T) {
 			},
 		},
 	}
-	for _, testCase := range testCases {
-		t.Run(testCase.name, func(t *testing.T) {
-			cleanUp()
-			for i, machineID := range testCase.machineIDs {
-				api.dbClient.Ent.Machine.Create().
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			apiClient := getAPIC(t)
+			cleanUp(apiClient)
+			for i, machineID := range tc.machineIDs {
+				apiClient.dbClient.Ent.Machine.Create().
 					SetMachineId(machineID).
 					SetPassword(testPassword.String()).
 					SetIpAddress(fmt.Sprintf("1.2.3.%d", i)).
@@ -336,8 +323,8 @@ func TestAPICGetMetrics(t *testing.T) {
 					ExecX(context.Background())
 			}
 
-			for i, bouncerName := range testCase.bouncers {
-				api.dbClient.Ent.Bouncer.Create().
+			for i, bouncerName := range tc.bouncers {
+				apiClient.dbClient.Ent.Bouncer.Create().
 					SetIPAddress(fmt.Sprintf("1.2.3.%d", i)).
 					SetName(bouncerName).
 					SetAPIKey("foobar").
@@ -346,19 +333,17 @@ func TestAPICGetMetrics(t *testing.T) {
 					ExecX(context.Background())
 			}
 
-			if foundMetrics, err := api.GetMetrics(); err != nil {
-				t.Fatal(err)
-			} else {
-				assert.Equal(t, foundMetrics.Bouncers, testCase.expectedMetric.Bouncers)
-				assert.Equal(t, foundMetrics.Machines, testCase.expectedMetric.Machines)
+			foundMetrics, err := apiClient.GetMetrics()
+			require.NoError(t, err)
+
+			assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
+			assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
 
-			}
 		})
 	}
 }
 
 func TestCreateAlertsForDecision(t *testing.T) {
-
 	httpBfDecisionList := &models.Decision{
 		Origin:   &SCOPE_LISTS,
 		Scenario: types.StrPtr("crowdsecurity/http-bf"),
@@ -427,10 +412,11 @@ func TestCreateAlertsForDecision(t *testing.T) {
 			},
 		},
 	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			if got := createAlertsForDecisions(tt.args.decisions); !reflect.DeepEqual(got, tt.want) {
-				t.Errorf("createAlertsForDecisions() = %v, want %v", got, tt.want)
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) {
+				t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want)
 			}
 		})
 	}
@@ -503,11 +489,12 @@ func TestFillAlertsWithDecisions(t *testing.T) {
 			},
 		},
 	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			add_counters, _ := makeAddAndDeleteCounters()
-			if got := fillAlertsWithDecisions(tt.args.alerts, tt.args.decisions, add_counters); !reflect.DeepEqual(got, tt.want) {
-				t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tt.want)
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			addCounters, _ := makeAddAndDeleteCounters()
+			if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) {
+				t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tc.want)
 			}
 		})
 	}
@@ -586,24 +573,19 @@ func TestAPICPullTop(t *testing.T) {
 		),
 	))
 	url, err := url.ParseRequestURI("http://api.crowdsec.net/")
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
+
 	apic, err := apiclient.NewDefaultClient(
 		url,
 		"/api",
 		fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
 		nil,
 	)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	api.apiClient = apic
 	err = api.PullTop()
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	assertTotalDecisionCount(t, api.dbClient, 5)
 	assertTotalValidDecisionCount(t, api.dbClient, 4)
@@ -619,24 +601,23 @@ func TestAPICPullTop(t *testing.T) {
 	for _, alert := range alerts {
 		alertScenario[alert.SourceScope]++
 	}
-	assert.Equal(t, len(alertScenario), 3)
-	assert.Equal(t, alertScenario[SCOPE_CAPI_ALIAS], 1)
-	assert.Equal(t, alertScenario["lists:crowdsecurity/ssh-bf"], 1)
-	assert.Equal(t, alertScenario["lists:crowdsecurity/http-bf"], 1)
+	assert.Equal(t, 3, len(alertScenario))
+	assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS])
+	assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
+	assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
 
 	for _, decisions := range validDecisions {
 		decisionScenarioFreq[decisions.Scenario]++
 	}
 
-	assert.Equal(t, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
-	assert.Equal(t, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
-	assert.Equal(t, decisionScenarioFreq["crowdsecurity/test1"], 1)
-	assert.Equal(t, decisionScenarioFreq["crowdsecurity/test2"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
 }
 
 func TestAPICPush(t *testing.T) {
-
-	testCases := []struct {
+	tests := []struct {
 		name          string
 		alerts        []*models.Alert
 		expectedCalls int
@@ -683,14 +664,14 @@ func TestAPICPush(t *testing.T) {
 		},
 	}
 
-	for _, testCase := range testCases {
-		t.Run(testCase.name, func(t *testing.T) {
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
 			api := getAPIC(t)
 			api.pushInterval = time.Millisecond
 			url, err := url.ParseRequestURI("http://api.crowdsec.net/")
-			if err != nil {
-				t.Fatal(err)
-			}
+			require.NoError(t, err)
+
 			httpmock.Activate()
 			defer httpmock.DeactivateAndReset()
 			apic, err := apiclient.NewDefaultClient(
@@ -699,31 +680,28 @@ func TestAPICPush(t *testing.T) {
 				fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
 				nil,
 			)
-			if err != nil {
-				t.Fatal(err)
-			}
+			require.NoError(t, err)
+
 			api.apiClient = apic
 			httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{}))
 			go func() {
-				api.alertToPush <- testCase.alerts
+				api.alertToPush <- tc.alerts
 				time.Sleep(time.Second)
 				api.Shutdown()
 			}()
-			if err := api.Push(); err != nil {
-				t.Fatal(err)
-			}
-			assert.Equal(t, httpmock.GetTotalCallCount(), testCase.expectedCalls)
+			err = api.Push()
+			require.NoError(t, err)
+			assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
 		})
 	}
 }
 
 func TestAPICSendMetrics(t *testing.T) {
-	api := getAPIC(t)
-	testCases := []struct {
+	tests := []struct {
 		name            string
 		duration        time.Duration
 		expectedCalls   int
-		setUp           func()
+		setUp           func(*apic)
 		metricsInterval time.Duration
 	}{
 		{
@@ -731,14 +709,15 @@ func TestAPICSendMetrics(t *testing.T) {
 			duration:        time.Millisecond * 30,
 			metricsInterval: time.Millisecond * 5,
 			expectedCalls:   5,
-			setUp:           func() {},
+			setUp:           func(api *apic) {},
 		},
 		{
 			name:            "with some metrics",
 			duration:        time.Millisecond * 30,
 			metricsInterval: time.Millisecond * 5,
 			expectedCalls:   5,
-			setUp: func() {
+			setUp: func(api *apic) {
+				api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
 				api.dbClient.Ent.Machine.Create().
 					SetMachineId("1234").
 					SetPassword(testPassword.String()).
@@ -748,6 +727,7 @@ func TestAPICSendMetrics(t *testing.T) {
 					SetUpdatedAt(time.Time{}).
 					ExecX(context.Background())
 
+				api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
 				api.dbClient.Ent.Bouncer.Create().
 					SetIPAddress("1.2.3.6").
 					SetName("someBouncer").
@@ -758,44 +738,49 @@ func TestAPICSendMetrics(t *testing.T) {
 			},
 		},
 	}
-	for _, testCase := range testCases {
-		t.Run(testCase.name, func(t *testing.T) {
-			api = getAPIC(t)
-			api.pushInterval = time.Millisecond
+
+	httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
+	httpmock.Activate()
+	defer httpmock.Deactivate()
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
 			url, err := url.ParseRequestURI("http://api.crowdsec.net/")
-			if err != nil {
-				t.Fatal(err)
-			}
-			httpmock.Activate()
-			defer httpmock.DeactivateAndReset()
-			apic, err := apiclient.NewDefaultClient(
+			require.NoError(t, err)
+
+			apiClient, err := apiclient.NewDefaultClient(
 				url,
 				"/api",
 				fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
 				nil,
 			)
-			if err != nil {
-				t.Fatal(err)
-			}
-			api.apiClient = apic
-			api.metricsInterval = testCase.metricsInterval
-			httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, []byte{}))
-			testCase.setUp()
+			require.NoError(t, err)
 
-			go func() {
-				if err := api.SendMetrics(); err != nil {
-					panic(err)
-				}
-			}()
-			time.Sleep(testCase.duration)
-			assert.LessOrEqual(t, absDiff(testCase.expectedCalls, httpmock.GetTotalCallCount()), 2)
+			api := getAPIC(t)
+			api.pushInterval = time.Millisecond
+			api.apiClient = apiClient
+			api.metricsInterval = tc.metricsInterval
+			tc.setUp(api)
+
+			stop := make(chan bool)
+			httpmock.ZeroCallCounters()
+			go api.SendMetrics(stop)
+			time.Sleep(tc.duration)
+			stop <- true
+
+			info := httpmock.GetCallCountInfo()
+			noResponderCalls := info["NO_RESPONDER"]
+			responderCalls := info["POST http://api.crowdsec.net/api/metrics/"]
+			assert.LessOrEqual(t, absDiff(tc.expectedCalls, responderCalls), 2)
+			assert.Zero(t, noResponderCalls)
 		})
 	}
 }
 
 func TestAPICPull(t *testing.T) {
 	api := getAPIC(t)
-	testCases := []struct {
+	tests := []struct {
 		name                  string
 		setUp                 func()
 		expectedDecisionCount int
@@ -820,14 +805,13 @@ func TestAPICPull(t *testing.T) {
 		},
 	}
 
-	for _, testCase := range testCases {
-		t.Run(testCase.name, func(t *testing.T) {
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
 			api = getAPIC(t)
 			api.pullInterval = time.Millisecond
 			url, err := url.ParseRequestURI("http://api.crowdsec.net/")
-			if err != nil {
-				t.Fatal(err)
-			}
+			require.NoError(t, err)
 			httpmock.Activate()
 			defer httpmock.DeactivateAndReset()
 			apic, err := apiclient.NewDefaultClient(
@@ -836,9 +820,7 @@ func TestAPICPull(t *testing.T) {
 				fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
 				nil,
 			)
-			if err != nil {
-				t.Fatal(err)
-			}
+			require.NoError(t, err)
 			api.apiClient = apic
 			httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX(
 				models.DecisionsStreamResponse{
@@ -854,7 +836,7 @@ func TestAPICPull(t *testing.T) {
 					},
 				},
 			)))
-			testCase.setUp()
+			tc.setUp()
 			var buf bytes.Buffer
 			go func() {
 				logrus.SetOutput(&buf)
@@ -865,15 +847,14 @@ func TestAPICPull(t *testing.T) {
 			//Slightly long because the CI runner for windows are slow, and this can lead to random failure
 			time.Sleep(time.Millisecond * 500)
 			logrus.SetOutput(os.Stderr)
-			assert.Contains(t, buf.String(), testCase.logContains)
-			assertTotalDecisionCount(t, api.dbClient, testCase.expectedDecisionCount)
+			assert.Contains(t, buf.String(), tc.logContains)
+			assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount)
 		})
 	}
 }
 
 func TestShouldShareAlert(t *testing.T) {
-
-	testCases := []struct {
+	tests := []struct {
 		name          string
 		consoleConfig *csconfig.ConsoleConfig
 		alert         *models.Alert
@@ -948,10 +929,11 @@ func TestShouldShareAlert(t *testing.T) {
 		},
 	}
 
-	for _, testCase := range testCases {
-		t.Run(testCase.name, func(t *testing.T) {
-			ret := shouldShareAlert(testCase.alert, testCase.consoleConfig)
-			assert.Equal(t, ret, testCase.expectedRet)
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			ret := shouldShareAlert(tc.alert, tc.consoleConfig)
+			assert.Equal(t, tc.expectedRet, ret)
 		})
 	}
 }

+ 1 - 4
pkg/apiserver/apiserver.go

@@ -311,10 +311,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
 			return nil
 		})
 		s.apic.metricsTomb.Go(func() error {
-			if err := s.apic.SendMetrics(); err != nil {
-				log.Errorf("capi metrics: %s", err)
-				return err
-			}
+			s.apic.SendMetrics(make(chan bool))
 			return nil
 		})
 	}

+ 14 - 0
pkg/cstest/utils.go

@@ -7,6 +7,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Copy(sourceFile string, destinationFile string) error {
@@ -110,6 +111,8 @@ func CopyDir(src string, dest string) error {
 }
 
 func AssertErrorContains(t *testing.T, err error, expectedErr string) {
+	t.Helper()
+
 	if expectedErr != "" {
 		assert.ErrorContains(t, err, expectedErr)
 		return
@@ -117,3 +120,14 @@ func AssertErrorContains(t *testing.T, err error, expectedErr string) {
 
 	assert.NoError(t, err)
 }
+
+func RequireErrorContains(t *testing.T, err error, expectedErr string) {
+	t.Helper()
+
+	if expectedErr != "" {
+		require.ErrorContains(t, err, expectedErr)
+		return
+	}
+
+	require.NoError(t, err)
+}