watcher_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package csplugin
  2. import (
  3. "context"
  4. "log"
  5. "runtime"
  6. "testing"
  7. "time"
  8. "github.com/crowdsecurity/crowdsec/pkg/models"
  9. "gopkg.in/tomb.v2"
  10. "gotest.tools/v3/assert"
  11. )
  12. var ctx = context.Background()
  13. func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) {
  14. testTomb.Kill(nil)
  15. <-pw.PluginEvents
  16. if err := testTomb.Wait(); err != nil {
  17. log.Fatal(err)
  18. }
  19. }
  20. func resetWatcherAlertCounter(pw *PluginWatcher) {
  21. pw.AlertCountByPluginName.Lock()
  22. for k := range pw.AlertCountByPluginName.data {
  23. pw.AlertCountByPluginName.data[k] = 0
  24. }
  25. pw.AlertCountByPluginName.Unlock()
  26. }
  27. func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) {
  28. for i := 0; i < n; i++ {
  29. pw.Inserts <- pluginName
  30. }
  31. }
  32. func listenChannelWithTimeout(ctx context.Context, channel chan string) error {
  33. select {
  34. case x := <-channel:
  35. log.Printf("received -> %v", x)
  36. case <-ctx.Done():
  37. return ctx.Err()
  38. }
  39. return nil
  40. }
  41. func TestPluginWatcherInterval(t *testing.T) {
  42. if runtime.GOOS == "windows" {
  43. t.Skip("Skipping test on windows because timing is not reliable")
  44. }
  45. pw := PluginWatcher{}
  46. alertsByPluginName := make(map[string][]*models.Alert)
  47. testTomb := tomb.Tomb{}
  48. configs := map[string]PluginConfig{
  49. "testPlugin": {
  50. GroupWait: time.Millisecond,
  51. },
  52. }
  53. pw.Init(configs, alertsByPluginName)
  54. pw.Start(&testTomb)
  55. ct, cancel := context.WithTimeout(ctx, time.Microsecond)
  56. defer cancel()
  57. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  58. assert.ErrorContains(t, err, "context deadline exceeded")
  59. resetTestTomb(&testTomb, &pw)
  60. testTomb = tomb.Tomb{}
  61. pw.Start(&testTomb)
  62. ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
  63. defer cancel()
  64. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  65. assert.NilError(t, err)
  66. resetTestTomb(&testTomb, &pw)
  67. // This is to avoid the int complaining
  68. }
  69. func TestPluginAlertCountWatcher(t *testing.T) {
  70. if runtime.GOOS == "windows" {
  71. t.Skip("Skipping test on windows because timing is not reliable")
  72. }
  73. pw := PluginWatcher{}
  74. alertsByPluginName := make(map[string][]*models.Alert)
  75. configs := map[string]PluginConfig{
  76. "testPlugin": {
  77. GroupThreshold: 5,
  78. },
  79. }
  80. testTomb := tomb.Tomb{}
  81. pw.Init(configs, alertsByPluginName)
  82. pw.Start(&testTomb)
  83. // Channel won't contain any events since threshold is not crossed.
  84. ct, cancel := context.WithTimeout(ctx, time.Second)
  85. defer cancel()
  86. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  87. assert.ErrorContains(t, err, "context deadline exceeded")
  88. // Channel won't contain any events since threshold is not crossed.
  89. resetWatcherAlertCounter(&pw)
  90. insertNAlertsToPlugin(&pw, 4, "testPlugin")
  91. ct, cancel = context.WithTimeout(ctx, time.Second)
  92. defer cancel()
  93. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  94. assert.ErrorContains(t, err, "context deadline exceeded")
  95. // Channel will contain an event since threshold is crossed.
  96. resetWatcherAlertCounter(&pw)
  97. insertNAlertsToPlugin(&pw, 5, "testPlugin")
  98. ct, cancel = context.WithTimeout(ctx, time.Second)
  99. defer cancel()
  100. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  101. assert.NilError(t, err)
  102. resetTestTomb(&testTomb, &pw)
  103. }