watcher_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package csplugin
import (
	"context"
	"errors"
	"log"
	"testing"
	"time"

	"github.com/crowdsecurity/crowdsec/pkg/models"
	"gopkg.in/tomb.v2"
	"gotest.tools/v3/assert"
)
  11. var ctx = context.Background()
  12. func resetTestTomb(testTomb *tomb.Tomb) {
  13. testTomb.Kill(nil)
  14. if err := testTomb.Wait(); err != nil {
  15. log.Fatal(err)
  16. }
  17. }
  18. func resetWatcherAlertCounter(pw *PluginWatcher) {
  19. pw.AlertCountByPluginName.Lock()
  20. for k := range pw.AlertCountByPluginName.data {
  21. pw.AlertCountByPluginName.data[k] = 0
  22. }
  23. pw.AlertCountByPluginName.Unlock()
  24. }
  25. func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) {
  26. for i := 0; i < n; i++ {
  27. pw.Inserts <- pluginName
  28. }
  29. }
  30. func listenChannelWithTimeout(ctx context.Context, channel chan string) error {
  31. select {
  32. case x := <-channel:
  33. log.Printf("received -> %v", x)
  34. case <-ctx.Done():
  35. return ctx.Err()
  36. }
  37. return nil
  38. }
  39. func TestPluginWatcherInterval(t *testing.T) {
  40. pw := PluginWatcher{}
  41. alertsByPluginName := make(map[string][]*models.Alert)
  42. testTomb := tomb.Tomb{}
  43. configs := map[string]PluginConfig{
  44. "testPlugin": {
  45. GroupWait: time.Millisecond,
  46. },
  47. }
  48. pw.Init(configs, alertsByPluginName)
  49. pw.Start(&testTomb)
  50. ct, cancel := context.WithTimeout(ctx, time.Microsecond)
  51. defer cancel()
  52. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  53. assert.ErrorContains(t, err, "context deadline exceeded")
  54. resetTestTomb(&testTomb)
  55. testTomb = tomb.Tomb{}
  56. pw.Start(&testTomb)
  57. ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
  58. defer cancel()
  59. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  60. assert.NilError(t, err)
  61. resetTestTomb(&testTomb)
  62. // This is to avoid the int complaining
  63. }
  64. func TestPluginAlertCountWatcher(t *testing.T) {
  65. pw := PluginWatcher{}
  66. alertsByPluginName := make(map[string][]*models.Alert)
  67. configs := map[string]PluginConfig{
  68. "testPlugin": {
  69. GroupThreshold: 5,
  70. },
  71. }
  72. testTomb := tomb.Tomb{}
  73. pw.Init(configs, alertsByPluginName)
  74. pw.Start(&testTomb)
  75. // Channel won't contain any events since threshold is not crossed.
  76. ct, cancel := context.WithTimeout(ctx, time.Second)
  77. defer cancel()
  78. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  79. assert.ErrorContains(t, err, "context deadline exceeded")
  80. // Channel won't contain any events since threshold is not crossed.
  81. resetWatcherAlertCounter(&pw)
  82. insertNAlertsToPlugin(&pw, 4, "testPlugin")
  83. ct, cancel = context.WithTimeout(ctx, time.Second)
  84. defer cancel()
  85. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  86. assert.ErrorContains(t, err, "context deadline exceeded")
  87. // Channel will contain an event since threshold is crossed.
  88. resetWatcherAlertCounter(&pw)
  89. insertNAlertsToPlugin(&pw, 5, "testPlugin")
  90. ct, cancel = context.WithTimeout(ctx, time.Second)
  91. defer cancel()
  92. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  93. assert.NilError(t, err)
  94. resetTestTomb(&testTomb)
  95. }