watcher_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package csplugin
  2. import (
  3. "context"
  4. "log"
  5. "runtime"
  6. "testing"
  7. "time"
  8. "github.com/crowdsecurity/crowdsec/pkg/models"
  9. "gopkg.in/tomb.v2"
  10. "gotest.tools/v3/assert"
  11. )
  12. var ctx = context.Background()
  13. func resetTestTomb(testTomb *tomb.Tomb) {
  14. testTomb.Kill(nil)
  15. if err := testTomb.Wait(); err != nil {
  16. log.Fatal(err)
  17. }
  18. }
  19. func resetWatcherAlertCounter(pw *PluginWatcher) {
  20. pw.AlertCountByPluginName.Lock()
  21. for k := range pw.AlertCountByPluginName.data {
  22. pw.AlertCountByPluginName.data[k] = 0
  23. }
  24. pw.AlertCountByPluginName.Unlock()
  25. }
  26. func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) {
  27. for i := 0; i < n; i++ {
  28. pw.Inserts <- pluginName
  29. }
  30. }
  31. func listenChannelWithTimeout(ctx context.Context, channel chan string) error {
  32. select {
  33. case x := <-channel:
  34. log.Printf("received -> %v", x)
  35. case <-ctx.Done():
  36. return ctx.Err()
  37. }
  38. return nil
  39. }
  40. func TestPluginWatcherInterval(t *testing.T) {
  41. if runtime.GOOS == "windows" {
  42. t.Skip("Skipping test on windows because timing is not reliable")
  43. }
  44. pw := PluginWatcher{}
  45. alertsByPluginName := make(map[string][]*models.Alert)
  46. testTomb := tomb.Tomb{}
  47. configs := map[string]PluginConfig{
  48. "testPlugin": {
  49. GroupWait: time.Millisecond,
  50. },
  51. }
  52. pw.Init(configs, alertsByPluginName)
  53. pw.Start(&testTomb)
  54. ct, cancel := context.WithTimeout(ctx, time.Microsecond)
  55. defer cancel()
  56. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  57. assert.ErrorContains(t, err, "context deadline exceeded")
  58. resetTestTomb(&testTomb)
  59. testTomb = tomb.Tomb{}
  60. pw.Start(&testTomb)
  61. ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
  62. defer cancel()
  63. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  64. assert.NilError(t, err)
  65. resetTestTomb(&testTomb)
  66. // This is to avoid the int complaining
  67. }
  68. func TestPluginAlertCountWatcher(t *testing.T) {
  69. if runtime.GOOS == "windows" {
  70. t.Skip("Skipping test on windows because timing is not reliable")
  71. }
  72. pw := PluginWatcher{}
  73. alertsByPluginName := make(map[string][]*models.Alert)
  74. configs := map[string]PluginConfig{
  75. "testPlugin": {
  76. GroupThreshold: 5,
  77. },
  78. }
  79. testTomb := tomb.Tomb{}
  80. pw.Init(configs, alertsByPluginName)
  81. pw.Start(&testTomb)
  82. // Channel won't contain any events since threshold is not crossed.
  83. ct, cancel := context.WithTimeout(ctx, time.Second)
  84. defer cancel()
  85. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  86. assert.ErrorContains(t, err, "context deadline exceeded")
  87. // Channel won't contain any events since threshold is not crossed.
  88. resetWatcherAlertCounter(&pw)
  89. insertNAlertsToPlugin(&pw, 4, "testPlugin")
  90. ct, cancel = context.WithTimeout(ctx, time.Second)
  91. defer cancel()
  92. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  93. assert.ErrorContains(t, err, "context deadline exceeded")
  94. // Channel will contain an event since threshold is crossed.
  95. resetWatcherAlertCounter(&pw)
  96. insertNAlertsToPlugin(&pw, 5, "testPlugin")
  97. ct, cancel = context.WithTimeout(ctx, time.Second)
  98. defer cancel()
  99. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  100. assert.NilError(t, err)
  101. resetTestTomb(&testTomb)
  102. }