watcher_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package csplugin
  2. import (
  3. "context"
  4. "log"
  5. "runtime"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/require"
  9. "gopkg.in/tomb.v2"
  10. "github.com/crowdsecurity/go-cs-lib/cstest"
  11. "github.com/crowdsecurity/crowdsec/pkg/models"
  12. )
  // ctx is the root context used by the tests in this file; every bounded
  // wait below derives its own context.WithTimeout child from it.
  var ctx context.Context = context.Background()
  14. func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) {
  15. testTomb.Kill(nil)
  16. <-pw.PluginEvents
  17. if err := testTomb.Wait(); err != nil {
  18. log.Fatal(err)
  19. }
  20. }
  21. func resetWatcherAlertCounter(pw *PluginWatcher) {
  22. pw.AlertCountByPluginName.Lock()
  23. for k := range pw.AlertCountByPluginName.data {
  24. pw.AlertCountByPluginName.data[k] = 0
  25. }
  26. pw.AlertCountByPluginName.Unlock()
  27. }
  28. func insertNAlertsToPlugin(pw *PluginWatcher, n int, pluginName string) {
  29. for i := 0; i < n; i++ {
  30. pw.Inserts <- pluginName
  31. }
  32. }
  // listenChannelWithTimeout waits for a single value on channel and logs it.
  //
  // It returns nil once a value is received (a closed channel also counts,
  // yielding ""). If ctx is done first, it returns ctx.Err(), e.g.
  // context.DeadlineExceeded for a context built with context.WithTimeout. When
  // both are ready at once, select picks one at random.
  //
  // channel is receive-only because this helper only consumes from it; callers
  // can still pass a bidirectional chan string.
  func listenChannelWithTimeout(ctx context.Context, channel <-chan string) error {
  	select {
  	case x := <-channel:
  		log.Printf("received -> %v", x)
  		return nil
  	case <-ctx.Done():
  		return ctx.Err()
  	}
  }
  42. func TestPluginWatcherInterval(t *testing.T) {
  43. if runtime.GOOS == "windows" {
  44. t.Skip("Skipping test on windows because timing is not reliable")
  45. }
  46. pw := PluginWatcher{}
  47. alertsByPluginName := make(map[string][]*models.Alert)
  48. testTomb := tomb.Tomb{}
  49. configs := map[string]PluginConfig{
  50. "testPlugin": {
  51. GroupWait: time.Millisecond,
  52. },
  53. }
  54. pw.Init(configs, alertsByPluginName)
  55. pw.Start(&testTomb)
  56. ct, cancel := context.WithTimeout(ctx, time.Microsecond)
  57. defer cancel()
  58. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  59. cstest.RequireErrorContains(t, err, "context deadline exceeded")
  60. resetTestTomb(&testTomb, &pw)
  61. testTomb = tomb.Tomb{}
  62. pw.Start(&testTomb)
  63. ct, cancel = context.WithTimeout(ctx, time.Millisecond*5)
  64. defer cancel()
  65. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  66. require.NoError(t, err)
  67. resetTestTomb(&testTomb, &pw)
  68. // This is to avoid the int complaining
  69. }
  70. func TestPluginAlertCountWatcher(t *testing.T) {
  71. if runtime.GOOS == "windows" {
  72. t.Skip("Skipping test on windows because timing is not reliable")
  73. }
  74. pw := PluginWatcher{}
  75. alertsByPluginName := make(map[string][]*models.Alert)
  76. configs := map[string]PluginConfig{
  77. "testPlugin": {
  78. GroupThreshold: 5,
  79. },
  80. }
  81. testTomb := tomb.Tomb{}
  82. pw.Init(configs, alertsByPluginName)
  83. pw.Start(&testTomb)
  84. // Channel won't contain any events since threshold is not crossed.
  85. ct, cancel := context.WithTimeout(ctx, time.Second)
  86. defer cancel()
  87. err := listenChannelWithTimeout(ct, pw.PluginEvents)
  88. cstest.RequireErrorContains(t, err, "context deadline exceeded")
  89. // Channel won't contain any events since threshold is not crossed.
  90. resetWatcherAlertCounter(&pw)
  91. insertNAlertsToPlugin(&pw, 4, "testPlugin")
  92. ct, cancel = context.WithTimeout(ctx, time.Second)
  93. defer cancel()
  94. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  95. cstest.RequireErrorContains(t, err, "context deadline exceeded")
  96. // Channel will contain an event since threshold is crossed.
  97. resetWatcherAlertCounter(&pw)
  98. insertNAlertsToPlugin(&pw, 5, "testPlugin")
  99. ct, cancel = context.WithTimeout(ctx, time.Second)
  100. defer cancel()
  101. err = listenChannelWithTimeout(ct, pw.PluginEvents)
  102. require.NoError(t, err)
  103. resetTestTomb(&testTomb, &pw)
  104. }