瀏覽代碼

do not send more than group_threshold alerts at once to a notification plugin (#2264)

* do not send more than group_threshold alerts at once to a notification plugin
* Use generic Chunks function, updated tests

---------

Co-authored-by: Sebastien Blot <sebastien@crowdsec.net>
mmetc 2 年之前
父節點
當前提交
e3cb4ab2c4
共有 2 個文件被更改,包括 45 次插入8 次删除
  1. 10 2
      pkg/csplugin/broker.go
  2. 35 6
      pkg/csplugin/broker_test.go

+ 10 - 2
pkg/csplugin/broker.go

@@ -20,6 +20,7 @@ import (
 	"gopkg.in/yaml.v2"
 
 	"github.com/crowdsecurity/go-cs-lib/pkg/csstring"
+	"github.com/crowdsecurity/go-cs-lib/pkg/slicetools"
 
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
@@ -116,8 +117,15 @@ loop:
 			pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0)
 			pluginMutex.Unlock()
 			go func() {
-				if err := pb.pushNotificationsToPlugin(pluginName, tmpAlerts); err != nil {
-					log.WithField("plugin:", pluginName).Error(err)
+				//Chunk alerts to respect group_threshold
+				threshold := pb.pluginConfigByName[pluginName].GroupThreshold
+				if threshold == 0 {
+					threshold = 1
+				}
+				for _, chunk := range slicetools.Chunks(tmpAlerts, threshold) {
+					if err := pb.pushNotificationsToPlugin(pluginName, chunk); err != nil {
+						log.WithField("plugin:", pluginName).Error(err)
+					}
 				}
 			}()
 

+ 35 - 6
pkg/csplugin/broker_test.go

@@ -3,7 +3,9 @@
 package csplugin
 
 import (
+	"bytes"
 	"encoding/json"
+	"io"
 	"os"
 	"testing"
 	"time"
@@ -278,21 +280,35 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
-	time.Sleep(100 * time.Millisecond)
+	time.Sleep(time.Second)
 
 	// because of group threshold, we shouldn't have data yet
 	assert.NoFileExists(t, "./out")
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
-	time.Sleep(100 * time.Millisecond)
+	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
+	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
+	time.Sleep(time.Second)
 
 	// and now we should
 	content, err := os.ReadFile("./out")
 	require.NoError(t, err, "Error reading file")
 
+	decoder := json.NewDecoder(bytes.NewReader(content))
+
 	var alerts []models.Alert
-	err = json.Unmarshal(content, &alerts)
+
+	// two notifications, one with 4 alerts, one with 2 alerts
+
+	err = decoder.Decode(&alerts)
 	assert.NoError(t, err)
 	assert.Len(t, alerts, 4)
+
+	err = decoder.Decode(&alerts)
+	assert.NoError(t, err)
+	assert.Len(t, alerts, 2)
+
+	err = decoder.Decode(&alerts)
+	assert.Equal(t, err, io.EOF)
 }
 
 func (s *PluginSuite) TestBrokerRunTimeThreshold() {
@@ -346,13 +362,26 @@ func (s *PluginSuite) TestBrokerRunSimple() {
 
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
 	pb.PluginChannel <- ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}}
-	time.Sleep(time.Millisecond * 200)
+	// make it wait a bit, CI can be slow
+	time.Sleep(time.Second)
 
 	content, err := os.ReadFile("./out")
 	require.NoError(t, err, "Error reading file")
 
+	decoder := json.NewDecoder(bytes.NewReader(content))
+
 	var alerts []models.Alert
-	err = json.Unmarshal(content, &alerts)
+
+	// two notifications, one alert each
+
+	err = decoder.Decode(&alerts)
 	assert.NoError(t, err)
-	assert.Len(t, alerts, 2)
+	assert.Len(t, alerts, 1)
+
+	err = decoder.Decode(&alerts)
+	assert.NoError(t, err)
+	assert.Len(t, alerts, 1)
+
+	err = decoder.Decode(&alerts)
+	assert.Equal(t, err, io.EOF)
 }