diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index ddb263368..1b96dfa20 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -10,7 +10,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/acquisition" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/parser" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -20,10 +19,7 @@ import ( ) func initCrowdsec(cConfig *csconfig.Config) (*parser.Parsers, error) { - err := exprhelpers.Init() - if err != nil { - return &parser.Parsers{}, fmt.Errorf("Failed to init expr helpers : %s", err) - } + var err error // Populate cwhub package tools if err := cwhub.GetHubIdx(cConfig.Hub); err != nil { diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index 161a54875..54044f20a 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -10,6 +10,8 @@ import ( "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" @@ -163,12 +165,12 @@ func shutdownCrowdsec() error { func shutdown(sig os.Signal, cConfig *csconfig.Config) error { if !cConfig.DisableAgent { if err := shutdownCrowdsec(); err != nil { - return errors.Wrap(err, "Failed to shut down crowdsec") + return errors.Wrap(err, "failed to shut down crowdsec") } } if !cConfig.DisableAPI { if err := shutdownAPI(); err != nil { - return errors.Wrap(err, "Failed to shut down api routines") + return errors.Wrap(err, "failed to shut down api routines") } } return nil @@ -227,6 +229,24 @@ func Serve(cConfig *csconfig.Config) error { apiTomb = tomb.Tomb{} crowdsecTomb = tomb.Tomb{} pluginTomb = tomb.Tomb{} + + if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { + dbClient, err := database.NewClient(cConfig.API.Server.DbConfig) + if err != nil { + return errors.Wrap(err, "failed to get database client") + } + err = exprhelpers.Init(dbClient) + if err != nil { + return errors.Wrap(err, "failed to init expr helpers") + } + } else { + err := exprhelpers.Init(nil) + if err != nil { + return errors.Wrap(err, "failed to init expr helpers") + } + log.Warningln("Exprhelpers loaded without database client.") + } + if !cConfig.DisableAPI { apiServer, err := initAPIServer(cConfig) if err != nil { diff --git a/config/profiles.yaml b/config/profiles.yaml index f4945b7a0..9d81c9298 100644 --- a/config/profiles.yaml +++ b/config/profiles.yaml @@ -5,6 +5,7 @@ filters: decisions: - type: ban duration: 4h +#duration_expr: Sprintf('%dh', (GetDecisionsCount(Alert.GetValue()) + 1) * 4) # notifications: # - slack_default # Set the webhook in /etc/crowdsec/notifications/slack.yaml before enabling this. # - splunk_default # Set the splunk url and token in /etc/crowdsec/notifications/splunk.yaml before enabling this. diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index 5dddf71fc..6375b0121 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -61,7 +61,7 @@ func (c *Controller) NewV1() error { v1Config := v1.ControllerV1Config{ DbClient: c.DBClient, Ctx: c.Ectx, - Profiles: c.Profiles, + ProfilesCfg: c.Profiles, CapiChan: c.CAPIChan, PluginChannel: c.PluginChannel, ConsoleConfig: *c.ConsoleConfig, diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index b6b59638a..4d79af907 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -11,7 +11,6 @@ import ( jwt "github.com/appleboy/gin-jwt/v2" "github.com/crowdsecurity/crowdsec/pkg/csplugin" - "github.com/crowdsecurity/crowdsec/pkg/csprofiles" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/gin-gonic/gin" @@ -135,7 +134,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { alert.MachineID = machineID if len(alert.Decisions) != 0 { for pIdx, profile := range c.Profiles { - _, matched, err := csprofiles.EvaluateProfile(profile, alert) + _, matched, err := profile.EvaluateProfile(alert) if err != nil { gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return @@ -144,7 +143,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { continue } c.sendAlertToPluginChannel(alert, uint(pIdx)) - if profile.OnSuccess == "break" { + if profile.Cfg.OnSuccess == "break" { break } } @@ -156,7 +155,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } for pIdx, profile := range c.Profiles { - profileDecisions, matched, err := csprofiles.EvaluateProfile(profile, alert) + profileDecisions, matched, err := profile.EvaluateProfile(alert) if err != nil { gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return @@ -171,7 +170,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } profileAlert := *alert c.sendAlertToPluginChannel(&profileAlert, uint(pIdx)) - if profile.OnSuccess == "break" { + if profile.Cfg.OnSuccess == "break" { break } } diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index 13e9f730c..f29d9291f 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -9,8 +9,10 @@ import ( middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/csprofiles" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/pkg/errors" ) type Controller struct { @@ -18,7 +20,7 @@ type Controller struct { DBClient *database.Client APIKeyHeader string Middlewares *middlewares.Middlewares - Profiles []*csconfig.ProfileCfg + Profiles []*csprofiles.Runtime CAPIChan chan []*models.Alert PluginChannel chan csplugin.ProfileAlert ConsoleConfig csconfig.ConsoleConfig @@ -28,7 +30,7 @@ type Controller struct { type ControllerV1Config struct { DbClient *database.Client Ctx context.Context - Profiles []*csconfig.ProfileCfg + ProfilesCfg []*csconfig.ProfileCfg CapiChan chan []*models.Alert PluginChannel chan csplugin.ProfileAlert ConsoleConfig csconfig.ConsoleConfig @@ -37,11 +39,17 @@ type ControllerV1Config struct { func New(cfg *ControllerV1Config) (*Controller, error) { var err error + + profiles, err := csprofiles.NewProfile(cfg.ProfilesCfg) + if err != nil { + return &Controller{}, errors.Wrapf(err, "failed to compile profiles") + } + v1 := &Controller{ Ectx: cfg.Ctx, DBClient: cfg.DbClient, APIKeyHeader: middlewares.APIKeyHeader, - Profiles: cfg.Profiles, + Profiles: profiles, CAPIChan: cfg.CapiChan, PluginChannel: cfg.PluginChannel, ConsoleConfig: cfg.ConsoleConfig, diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index 9f4e8794e..f34042358 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -231,7 +231,7 @@ func TestLoadAPIServer(t *testing.T) { err: "", }, { - name: "basic valid configuration", + name: "basic invalid configuration", Input: &Config{ Self: []byte(configData), API: &APICfg{ diff --git a/pkg/csconfig/profiles.go b/pkg/csconfig/profiles.go index e2a1bbb37..05072c263 100644 --- a/pkg/csconfig/profiles.go +++ b/pkg/csconfig/profiles.go @@ -4,29 +4,23 @@ import ( "bytes" "fmt" "io" - "time" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) //Profile structure(s) are used by the local API to "decide" what kind of decision should be applied when a scenario with an active remediation has been triggered type ProfileCfg struct { - Name string `yaml:"name,omitempty"` - Debug *bool `yaml:"debug,omitempty"` - Filters []string `yaml:"filters,omitempty"` //A list of OR'ed expressions. the models.Alert object - RuntimeFilters []*vm.Program `json:"-" yaml:"-"` - DebugFilters []*exprhelpers.ExprDebugger `json:"-" yaml:"-"` - Decisions []models.Decision `yaml:"decisions,omitempty"` - OnSuccess string `yaml:"on_success,omitempty"` //continue or break - OnFailure string `yaml:"on_failure,omitempty"` //continue or break - Notifications []string `yaml:"notifications,omitempty"` + Name string `yaml:"name,omitempty"` + Debug *bool `yaml:"debug,omitempty"` + Filters []string `yaml:"filters,omitempty"` //A list of OR'ed expressions. the models.Alert object + Decisions []models.Decision `yaml:"decisions,omitempty"` + DurationExpr string `yaml:"duration_expr,omitempty"` + OnSuccess string `yaml:"on_success,omitempty"` //continue or break + OnFailure string `yaml:"on_failure,omitempty"` //continue or break + Notifications []string `yaml:"notifications,omitempty"` } func (c *LocalApiServerCfg) LoadProfiles() error { @@ -56,33 +50,6 @@ func (c *LocalApiServerCfg) LoadProfiles() error { c.Profiles = append(c.Profiles, &t) } - for pIdx, profile := range c.Profiles { - var runtimeFilter *vm.Program - var debugFilter *exprhelpers.ExprDebugger - - c.Profiles[pIdx].RuntimeFilters = make([]*vm.Program, len(profile.Filters)) - c.Profiles[pIdx].DebugFilters = make([]*exprhelpers.ExprDebugger, len(profile.Filters)) - - for fIdx, filter := range profile.Filters { - if runtimeFilter, err = expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { - return errors.Wrapf(err, "Error compiling filter of %s", profile.Name) - } - c.Profiles[pIdx].RuntimeFilters[fIdx] = runtimeFilter - if debugFilter, err = exprhelpers.NewDebugger(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { - log.Debugf("Error compiling debug filter of %s : %s", profile.Name, err) - // Don't fail if we can't compile the filter - for now - // return errors.Wrapf(err, "Error compiling debug filter of %s", profile.Name) - } - c.Profiles[pIdx].DebugFilters[fIdx] = debugFilter - } - - for _, decision := range profile.Decisions { - if _, err := time.ParseDuration(*decision.Duration); err != nil { - return errors.Wrapf(err, "Error parsing duration '%s' of %s", *decision.Duration, profile.Name) - } - } - - } if len(c.Profiles) == 0 { return fmt.Errorf("zero profiles loaded for LAPI") } diff --git a/pkg/csconfig/tests/profiles.yaml b/pkg/csconfig/tests/profiles.yaml index 5727f4edd..8468b240e 100644 --- a/pkg/csconfig/tests/profiles.yaml +++ b/pkg/csconfig/tests/profiles.yaml @@ -29,3 +29,13 @@ decisions: - type: ratatatata duration: 1h on_success: break +--- +name: duration_expression +#debug: true +filters: + - Alert.Remediation == true && Alert.GetScope() == "Ip" +decisions: + - type: ban + duration: 1h +duration_expr: sprintf('%dh', 4*4) +on_success: break diff --git a/pkg/csprofiles/csprofiles.go b/pkg/csprofiles/csprofiles.go index 76c7592b9..b2ca6026e 100644 --- a/pkg/csprofiles/csprofiles.go +++ b/pkg/csprofiles/csprofiles.go @@ -2,8 +2,10 @@ package csprofiles import ( "fmt" + "time" "github.com/antonmedv/expr" + "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -12,10 +14,86 @@ import ( log "github.com/sirupsen/logrus" ) -func GenerateDecisionFromProfile(Profile *csconfig.ProfileCfg, Alert *models.Alert) ([]*models.Decision, error) { +type Runtime struct { + RuntimeFilters []*vm.Program `json:"-" yaml:"-"` + DebugFilters []*exprhelpers.ExprDebugger `json:"-" yaml:"-"` + RuntimeDurationExpr *vm.Program `json:"-" yaml:"-"` + DebugDurationExpr *exprhelpers.ExprDebugger `json:"-" yaml:"-"` + Cfg *csconfig.ProfileCfg `json:"-" yaml:"-"` + Logger *log.Entry `json:"-" yaml:"-"` +} + +var defaultDuration = "4h" + +func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { + var err error + profilesRuntime := make([]*Runtime, 0) + + for _, profile := range profilesCfg { + var runtimeFilter, runtimeDurationExpr *vm.Program + var debugFilter, debugDurationExpr *exprhelpers.ExprDebugger + runtime := &Runtime{} + xlog := log.New() + if err := types.ConfigureLogger(xlog); err != nil { + log.Fatalf("While creating profiles-specific logger : %s", err) + } + xlog.SetLevel(log.InfoLevel) + runtime.Logger = xlog.WithFields(log.Fields{ + "type": "profile", + "name": profile.Name, + }) + + runtime.RuntimeFilters = make([]*vm.Program, len(profile.Filters)) + runtime.DebugFilters = make([]*exprhelpers.ExprDebugger, len(profile.Filters)) + runtime.Cfg = profile + + for fIdx, filter := range profile.Filters { + if runtimeFilter, err = expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + return []*Runtime{}, errors.Wrapf(err, "error compiling filter of '%s'", profile.Name) + } + runtime.RuntimeFilters[fIdx] = runtimeFilter + if profile.Debug != nil && *profile.Debug { + if debugFilter, err = exprhelpers.NewDebugger(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + log.Debugf("Error compiling debug filter of %s : %s", profile.Name, err) + // Don't fail if we can't compile the filter - for now + // return errors.Wrapf(err, "Error compiling debug filter of %s", profile.Name) + } + runtime.DebugFilters[fIdx] = debugFilter + runtime.Logger.Logger.SetLevel(log.DebugLevel) + } + } + + if profile.DurationExpr != "" { + if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + return []*Runtime{}, errors.Wrapf(err, "error compiling duration_expr of %s", profile.Name) + } + + runtime.RuntimeDurationExpr = runtimeDurationExpr + if profile.Debug != nil && *profile.Debug { + if debugDurationExpr, err = exprhelpers.NewDebugger(profile.DurationExpr, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + log.Debugf("Error compiling debug duration_expr of %s : %s", profile.Name, err) + } + runtime.DebugDurationExpr = debugDurationExpr + } + } + + for _, decision := range profile.Decisions { + if runtime.RuntimeDurationExpr == nil { + if _, err := time.ParseDuration(*decision.Duration); err != nil { + return []*Runtime{}, errors.Wrapf(err, "error parsing duration '%s' of %s", *decision.Duration, profile.Name) + } + } + } + + profilesRuntime = append(profilesRuntime, runtime) + } + return profilesRuntime, nil +} + +func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*models.Decision, error) { var decisions []*models.Decision - for _, refDecision := range Profile.Decisions { + for _, refDecision := range Profile.Cfg.Decisions { decision := models.Decision{} /*the reference decision from profile is in sumulated mode */ if refDecision.Simulated != nil && *refDecision.Simulated { @@ -36,7 +114,27 @@ func GenerateDecisionFromProfile(Profile *csconfig.ProfileCfg, Alert *models.Ale } /*some fields are populated from the reference object : duration, scope, type*/ decision.Duration = new(string) - *decision.Duration = *refDecision.Duration + if Profile.Cfg.DurationExpr != "" && Profile.RuntimeDurationExpr != nil { + duration, err := expr.Run(Profile.RuntimeDurationExpr, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) + if err != nil { + Profile.Logger.Warningf("Failed to run duration_expr : %v", err) + *decision.Duration = *refDecision.Duration + } else { + durationStr := fmt.Sprint(duration) + if _, err := time.ParseDuration(durationStr); err != nil { + Profile.Logger.Warningf("Failed to parse expr duration result '%s'", duration) + *decision.Duration = *refDecision.Duration + } else { + *decision.Duration = durationStr + } + } + } else { + if refDecision.Duration == nil { + *decision.Duration = defaultDuration + } + *decision.Duration = *refDecision.Duration + } + decision.Type = new(string) *decision.Type = *refDecision.Type @@ -55,54 +153,44 @@ func GenerateDecisionFromProfile(Profile *csconfig.ProfileCfg, Alert *models.Ale return decisions, nil } -var clog *log.Entry - //EvaluateProfile is going to evaluate an Alert against a profile to generate Decisions -func EvaluateProfile(profile *csconfig.ProfileCfg, Alert *models.Alert) ([]*models.Decision, bool, error) { +func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision, bool, error) { var decisions []*models.Decision - if clog == nil { - xlog := log.New() - if err := types.ConfigureLogger(xlog); err != nil { - log.Fatalf("While creating profiles-specific logger : %s", err) - } - xlog.SetLevel(log.TraceLevel) - clog = xlog.WithFields(log.Fields{ - "type": "profile", - }) - } + matched := false - for eIdx, expression := range profile.RuntimeFilters { + for eIdx, expression := range Profile.RuntimeFilters { output, err := expr.Run(expression, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) if err != nil { - log.Warningf("failed to run whitelist expr : %v", err) - return nil, matched, errors.Wrapf(err, "while running expression %s", profile.Filters[eIdx]) + Profile.Logger.Warningf("failed to run whitelist expr : %v", err) + return nil, matched, errors.Wrapf(err, "while running expression %s", Profile.Cfg.Filters[eIdx]) } switch out := output.(type) { case bool: - if profile.Debug != nil && *profile.Debug { - profile.DebugFilters[eIdx].Run(clog, out, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) + if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { + Profile.DebugFilters[eIdx].Run(Profile.Logger, out, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) } if out { matched = true /*the expression matched, create the associated decision*/ - subdecisions, err := GenerateDecisionFromProfile(profile, Alert) + subdecisions, err := Profile.GenerateDecisionFromProfile(Alert) if err != nil { - return nil, matched, errors.Wrapf(err, "while generating decision from profile %s", profile.Name) + return nil, matched, errors.Wrapf(err, "while generating decision from profile %s", Profile.Cfg.Name) } decisions = append(decisions, subdecisions...) } else { - log.Debugf("Profile %s filter is unsuccessful", profile.Name) - if profile.OnFailure == "break" { + Profile.Logger.Debugf("Profile %s filter is unsuccessful", Profile.Cfg.Name) + if Profile.Cfg.OnFailure == "break" { break } } default: - return nil, matched, fmt.Errorf("unexpected type %t (%v) while running '%s'", output, output, profile.Filters[eIdx]) + return nil, matched, fmt.Errorf("unexpected type %t (%v) while running '%s'", output, output, Profile.Cfg.Filters[eIdx]) } } + return decisions, matched, nil } diff --git a/pkg/csprofiles/csprofiles_test.go b/pkg/csprofiles/csprofiles_test.go index 645532a03..258319748 100644 --- a/pkg/csprofiles/csprofiles_test.go +++ b/pkg/csprofiles/csprofiles_test.go @@ -5,40 +5,117 @@ import ( "reflect" "testing" - "github.com/antonmedv/expr" - "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" + "gotest.tools/v3/assert" ) var ( scope = "Country" typ = "ban" - simulated = false + boolFalse = false + boolTrue = true duration = "1h" value = "CH" scenario = "ssh-bf" ) +func TestNewProfile(t *testing.T) { + tests := []struct { + name string + profileCfg *csconfig.ProfileCfg + expectedNbProfile int + }{ + { + name: "filter ok and duration_expr ok", + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{ + "1==1", + }, + DurationExpr: "1==1", + Debug: &boolFalse, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolTrue, Duration: &duration}, + }, + }, + expectedNbProfile: 1, + }, + { + name: "filter NOK and duration_expr ok", + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{ + "1==1", + "unknownExprHelper() == 'foo'", + }, + DurationExpr: "1==1", + Debug: &boolFalse, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolFalse, Duration: &duration}, + }, + }, + expectedNbProfile: 0, + }, + { + name: "filter ok and duration_expr NOK", + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{ + "1==1", + }, + DurationExpr: "unknownExprHelper() == 'foo'", + Debug: &boolFalse, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolFalse, Duration: &duration}, + }, + }, + expectedNbProfile: 0, + }, + { + name: "filter ok and duration_expr ok + DEBUG", + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{ + "1==1", + }, + DurationExpr: "1==1", + Debug: &boolTrue, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolFalse, Duration: &duration}, + }, + }, + expectedNbProfile: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + profilesCfg := []*csconfig.ProfileCfg{ + test.profileCfg, + } + profile, _ := NewProfile(profilesCfg) + fmt.Printf("expected : %+v | result : %+v", test.expectedNbProfile, len(profile)) + assert.Equal(t, test.expectedNbProfile, len(profile)) + }) + } +} + func TestEvaluateProfile(t *testing.T) { type args struct { - profile *csconfig.ProfileCfg - Alert *models.Alert + profileCfg *csconfig.ProfileCfg + Alert *models.Alert } tests := []struct { name string args args expectedDecisionCount int // count of expected decisions + expectedDuration string expectedMatchStatus bool }{ { name: "simple pass single expr", args: args{ - profile: &csconfig.ProfileCfg{ - Filters: []string{fmt.Sprintf("Alert.GetScenario() == \"%s\"", scenario)}, - RuntimeFilters: []*vm.Program{}, + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{fmt.Sprintf("Alert.GetScenario() == \"%s\"", scenario)}, + Debug: &boolFalse, }, Alert: &models.Alert{Remediation: true, Scenario: &scenario}, }, @@ -48,9 +125,8 @@ func TestEvaluateProfile(t *testing.T) { { name: "simple fail single expr", args: args{ - profile: &csconfig.ProfileCfg{ - Filters: []string{"Alert.GetScenario() == \"Foo\""}, - RuntimeFilters: []*vm.Program{}, + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{"Alert.GetScenario() == \"Foo\""}, }, Alert: &models.Alert{Remediation: true}, }, @@ -60,9 +136,8 @@ func TestEvaluateProfile(t *testing.T) { { name: "1 expr fail 1 expr pass should still eval to match", args: args{ - profile: &csconfig.ProfileCfg{ - Filters: []string{"1==1", "1!=1"}, - RuntimeFilters: []*vm.Program{}, + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{"1==1", "1!=1"}, }, Alert: &models.Alert{Remediation: true}, }, @@ -72,12 +147,11 @@ func TestEvaluateProfile(t *testing.T) { { name: "simple filter with 2 decision", args: args{ - profile: &csconfig.ProfileCfg{ - Filters: []string{"1==1"}, - RuntimeFilters: []*vm.Program{}, + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{"1==1"}, Decisions: []models.Decision{ - {Type: &typ, Scope: &scope, Simulated: &simulated, Duration: &duration}, - {Type: &typ, Scope: &scope, Simulated: &simulated, Duration: &duration}, + {Type: &typ, Scope: &scope, Simulated: &boolTrue, Duration: &duration}, + {Type: &typ, Scope: &scope, Simulated: &boolFalse, Duration: &duration}, }, }, Alert: &models.Alert{Remediation: true, Scenario: &scenario, Source: &models.Source{Value: &value}}, @@ -85,20 +159,42 @@ func TestEvaluateProfile(t *testing.T) { expectedDecisionCount: 2, expectedMatchStatus: true, }, + { + name: "simple filter with decision_expr", + args: args{ + profileCfg: &csconfig.ProfileCfg{ + Filters: []string{"1==1"}, + Decisions: []models.Decision{ + {Type: &typ, Scope: &scope, Simulated: &boolFalse}, + }, + DurationExpr: "Sprintf('%dh', 4*4)", + }, + Alert: &models.Alert{Remediation: true, Scenario: &scenario, Source: &models.Source{Value: &value}}, + }, + expectedDecisionCount: 1, + expectedDuration: "16h", + expectedMatchStatus: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - for _, filter := range tt.args.profile.Filters { - runtimeFilter, _ := expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))) - tt.args.profile.RuntimeFilters = append(tt.args.profile.RuntimeFilters, runtimeFilter) + profilesCfg := []*csconfig.ProfileCfg{ + tt.args.profileCfg, } - got, got1, _ := EvaluateProfile(tt.args.profile, tt.args.Alert) + profile, err := NewProfile(profilesCfg) + if err != nil { + t.Errorf("failed to get newProfile : %+v", err) + } + got, got1, _ := profile[0].EvaluateProfile(tt.args.Alert) if !reflect.DeepEqual(len(got), tt.expectedDecisionCount) { t.Errorf("EvaluateProfile() got = %+v, want %+v", got, tt.expectedDecisionCount) } if got1 != tt.expectedMatchStatus { t.Errorf("EvaluateProfile() got1 = %v, want %v", got1, tt.expectedMatchStatus) } + if tt.expectedDuration != "" { + assert.Equal(t, tt.expectedDuration, *got[0].Duration, "The two durations should be the same") + } }) } } diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 92d83bd8b..7e700f9d8 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -188,7 +188,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in } duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { - return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%v' : %s", decisionItem.Duration, err) + return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) } if decisionItem.Scope == nil { log.Warning("nil scope in community decision") @@ -425,7 +425,7 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { - return []string{}, errors.Wrapf(ParseDurationFail, "decision duration '%v' : %s", decisionItem.Duration, err) + return []string{}, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) } /*if the scope is IP or Range, convert the value to integers */ diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 710d37b5e..46950462e 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -115,74 +115,9 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] } } - if ip_sz == 4 { - - if contains { /*decision contains {start_ip,end_ip}*/ - query = query.Where(decision.And( - decision.StartIPLTE(start_ip), - decision.EndIPGTE(end_ip), - decision.IPSizeEQ(int64(ip_sz)), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - query = query.Where(decision.And( - decision.StartIPGTE(start_ip), - decision.EndIPLTE(end_ip), - decision.IPSizeEQ(int64(ip_sz)), - )) - } - } else if ip_sz == 16 { - - if contains { /*decision contains {start_ip,end_ip}*/ - query = query.Where(decision.And( - //matching addr size - decision.IPSizeEQ(int64(ip_sz)), - decision.Or( - //decision.start_ip < query.start_ip - decision.StartIPLT(start_ip), - decision.And( - //decision.start_ip == query.start_ip - decision.StartIPEQ(start_ip), - //decision.start_suffix <= query.start_suffix - decision.StartSuffixLTE(start_sfx), - )), - decision.Or( - //decision.end_ip > query.end_ip - decision.EndIPGT(end_ip), - decision.And( - //decision.end_ip == query.end_ip - decision.EndIPEQ(end_ip), - //decision.end_suffix >= query.end_suffix - decision.EndSuffixGTE(end_sfx), - ), - ), - )) - } else { /*decision is contained {start_ip,end_ip}*/ - query = query.Where(decision.And( - //matching addr size - decision.IPSizeEQ(int64(ip_sz)), - decision.Or( - //decision.start_ip > query.start_ip - decision.StartIPGT(start_ip), - decision.And( - //decision.start_ip == query.start_ip - decision.StartIPEQ(start_ip), - //decision.start_suffix >= query.start_suffix - decision.StartSuffixGTE(start_sfx), - )), - decision.Or( - //decision.end_ip < query.end_ip - decision.EndIPLT(end_ip), - decision.And( - //decision.end_ip == query.end_ip - decision.EndIPEQ(end_ip), - //decision.end_suffix <= query.end_suffix - decision.EndSuffixLTE(end_sfx), - ), - ), - )) - } - } else if ip_sz != 0 { - return nil, nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + query, err = applyStartIpEndIpFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return nil, nil, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } return query, joinPredicate, nil } @@ -594,6 +529,133 @@ func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, error) { return nbUpdated, nil } +func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz, count int + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + + if err != nil { + return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query() + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") + } + + count, err = decisions.Count(c.CTX) + if err != nil { + return 0, errors.Wrapf(err, "fail to count decisions") + } + + return count, nil +} + +func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz, count int + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + + if err != nil { + return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) + } + + contains := true + decisions := c.Ent.Decision.Query().Where( + decision.CreatedAtGT(since), + decision.UntilGT(time.Now().UTC()), + ) + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") + } + count, err = decisions.Count(c.CTX) + if err != nil { + return 0, errors.Wrapf(err, "fail to count decisions") + } + + return count, nil +} + +func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { + if ip_sz == 4 { + if contains { + /*Decision contains {start_ip,end_ip}*/ + decisions = decisions.Where(decision.And( + decision.StartIPLTE(start_ip), + decision.EndIPGTE(end_ip), + decision.IPSizeEQ(int64(ip_sz)), + )) + } else { + /*Decision is contained within {start_ip,end_ip}*/ + decisions = decisions.Where(decision.And( + decision.StartIPGTE(start_ip), + decision.EndIPLTE(end_ip), + decision.IPSizeEQ(int64(ip_sz)), + )) + } + } else if ip_sz == 16 { + /*decision contains {start_ip,end_ip}*/ + if contains { + decisions = decisions.Where(decision.And( + //matching addr size + decision.IPSizeEQ(int64(ip_sz)), + decision.Or( + //decision.start_ip < query.start_ip + decision.StartIPLT(start_ip), + decision.And( + //decision.start_ip == query.start_ip + decision.StartIPEQ(start_ip), + //decision.start_suffix <= query.start_suffix + decision.StartSuffixLTE(start_sfx), + )), + decision.Or( + //decision.end_ip > query.end_ip + decision.EndIPGT(end_ip), + decision.And( + //decision.end_ip == query.end_ip + decision.EndIPEQ(end_ip), + //decision.end_suffix >= query.end_suffix + decision.EndSuffixGTE(end_sfx), + ), + ), + )) + } else { + /*decision is contained within {start_ip,end_ip}*/ + decisions = decisions.Where(decision.And( + //matching addr size + decision.IPSizeEQ(int64(ip_sz)), + decision.Or( + //decision.start_ip > query.start_ip + decision.StartIPGT(start_ip), + decision.And( + //decision.start_ip == query.start_ip + decision.StartIPEQ(start_ip), + //decision.start_suffix >= query.start_suffix + decision.StartSuffixGTE(start_sfx), + )), + decision.Or( + //decision.end_ip < query.end_ip + decision.EndIPLT(end_ip), + decision.And( + //decision.end_ip == query.end_ip + decision.EndIPEQ(end_ip), + //decision.end_suffix <= query.end_suffix + decision.EndSuffixLTE(end_sfx), + ), + ), + )) + } + } else if ip_sz != 0 { + return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) + } + return decisions, nil +} + func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision { words := strings.Split(s, ",") predicates := make([]predicate.Decision, len(words)) diff --git a/pkg/exprhelpers/exprlib.go b/pkg/exprhelpers/exprlib.go index 47396ebbf..af3034922 100644 --- a/pkg/exprhelpers/exprlib.go +++ b/pkg/exprhelpers/exprlib.go @@ -14,12 +14,14 @@ import ( "github.com/c-robinson/iplib" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/davecgh/go-spew/spew" log "github.com/sirupsen/logrus" ) var dataFile map[string][]string var dataFileRegex map[string][]*regexp.Regexp +var dbClient *database.Client func Atof(x string) float64 { log.Debugf("debug atof %s", x) @@ -40,28 +42,31 @@ func Lower(s string) string { func GetExprEnv(ctx map[string]interface{}) map[string]interface{} { var ExprLib = map[string]interface{}{ - "Atof": Atof, - "JsonExtract": JsonExtract, - "JsonExtractUnescape": JsonExtractUnescape, - "JsonExtractLib": JsonExtractLib, - "JsonExtractSlice": JsonExtractSlice, - "JsonExtractObject": JsonExtractObject, - "ToJsonString": ToJson, - "File": File, - "RegexpInFile": RegexpInFile, - "Upper": Upper, - "Lower": Lower, - "IpInRange": IpInRange, - "TimeNow": TimeNow, - "ParseUri": ParseUri, - "PathUnescape": PathUnescape, - "QueryUnescape": QueryUnescape, - "PathEscape": PathEscape, - "QueryEscape": QueryEscape, - "XMLGetAttributeValue": XMLGetAttributeValue, - "XMLGetNodeValue": XMLGetNodeValue, - "IpToRange": IpToRange, - "IsIPV6": IsIPV6, + "Atof": Atof, + "JsonExtract": JsonExtract, + "JsonExtractUnescape": JsonExtractUnescape, + "JsonExtractLib": JsonExtractLib, + "JsonExtractSlice": JsonExtractSlice, + "JsonExtractObject": JsonExtractObject, + "ToJsonString": ToJson, + "File": File, + "RegexpInFile": RegexpInFile, + "Upper": Upper, + "Lower": Lower, + "IpInRange": IpInRange, + "TimeNow": TimeNow, + "ParseUri": ParseUri, + "PathUnescape": PathUnescape, + "QueryUnescape": QueryUnescape, + "PathEscape": PathEscape, + "QueryEscape": QueryEscape, + "XMLGetAttributeValue": XMLGetAttributeValue, + "XMLGetNodeValue": XMLGetNodeValue, + "IpToRange": IpToRange, + "IsIPV6": IsIPV6, + "GetDecisionsCount": GetDecisionsCount, + "GetDecisionsSinceCount": GetDecisionsSinceCount, + "Sprintf": fmt.Sprintf, } for k, v := range ctx { ExprLib[k] = v @@ -69,9 +74,10 @@ func GetExprEnv(ctx map[string]interface{}) map[string]interface{} { return ExprLib } -func Init() error { +func Init(databaseClient *database.Client) error { dataFile = make(map[string][]string) dataFileRegex = make(map[string][]*regexp.Regexp) + dbClient = databaseClient return nil } @@ -242,3 +248,35 @@ func KeyExists(key string, dict map[string]interface{}) bool { _, ok := dict[key] return ok } + +func GetDecisionsCount(value string) int { + if dbClient == nil { + log.Error("No database config to call GetDecisionsCount()") + return 0 + } + count, err := dbClient.CountDecisionsByValue(value) + if err != nil { + log.Errorf("Failed to get decisions count from value '%s'", value) + return 0 + } + return count +} + +func GetDecisionsSinceCount(value string, since string) int { + if dbClient == nil { + log.Error("No database config to call GetDecisionsCount()") + return 0 + } + sinceDuration, err := time.ParseDuration(since) + if err != nil { + log.Errorf("Failed to parse since parameter '%s' : %s", since, err) + return 0 + } + sinceTime := time.Now().UTC().Add(-sinceDuration) + count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + if err != nil { + log.Errorf("Failed to get decisions count from value '%s'", value) + return 0 + } + return count +} diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 2f59e3e4e..bb404e4da 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -1,9 +1,17 @@ package exprhelpers import ( + "context" "fmt" + "os" "time" + "github.com/pkg/errors" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "testing" @@ -17,8 +25,25 @@ var ( TestFolder = "tests" ) +func getDBClient(t *testing.T) *database.Client { + t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") + if err != nil { + t.Fatal(err) + } + testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + Type: "sqlite", + DbName: "crowdsec", + DbPath: dbPath.Name(), + }) + if err != nil { + t.Fatal(err) + } + return testDbClient +} + func TestVisitor(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -105,7 +130,7 @@ func TestVisitor(t *testing.T) { } func TestRegexpInFile(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -162,7 +187,7 @@ func TestRegexpInFile(t *testing.T) { } func TestFileInit(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -230,7 +255,7 @@ func TestFileInit(t *testing.T) { } func TestFile(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -731,3 +756,218 @@ func TestLower(t *testing.T) { log.Printf("test '%s' : OK", test.name) } } + +func TestGetDecisionsCount(t *testing.T) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz int + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + + if decision == nil { + assert.Error(t, errors.Errorf("Failed to create sample decision")) + } + + tests := []struct { + name string + env map[string]interface{} + code string + result string + err string + }{ + { + name: "GetDecisionsCount() test: existing IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + "GetDecisionsCount": GetDecisionsCount, + "sprintf": fmt.Sprintf, + }, + code: "sprintf('%d', GetDecisionsCount(Alert.GetValue()))", + result: "1", + err: "", + }, + { + name: "GetDecisionsCount() test: unknown IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + "GetDecisionsCount": GetDecisionsCount, + "sprintf": fmt.Sprintf, + }, + code: "sprintf('%d', GetDecisionsCount(Alert.GetValue()))", + result: "0", + err: "", + }, + } + + for _, test := range tests { + program, err := expr.Compile(test.code, expr.Env(GetExprEnv(test.env))) + require.NoError(t, err) + output, err := expr.Run(program, GetExprEnv(test.env)) + require.NoError(t, err) + require.Equal(t, test.result, output) + log.Printf("test '%s' : OK", test.name) + } +} +func TestGetDecisionsSinceCount(t *testing.T) { + var err error + var start_ip, start_sfx, end_ip, end_sfx int64 + var ip_sz int + existingIP := "1.2.3.4" + unknownIP := "1.2.3.5" + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + if err != nil { + t.Errorf("unable to convert '%s' to int: %s", existingIP, err) + } + // Add sample data to DB + dbClient = getDBClient(t) + + decision := dbClient.Ent.Decision.Create(). + SetUntil(time.Now().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + if decision == nil { + assert.Error(t, errors.Errorf("Failed to create sample decision")) + } + decision2 := dbClient.Ent.Decision.Create(). + SetCreatedAt(time.Now().AddDate(0, 0, -1)). + SetUntil(time.Now().Add(time.Hour)). + SetScenario("crowdsec/test"). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(ip_sz)). + SetType("ban"). + SetScope("IP"). + SetValue(existingIP). + SetOrigin("CAPI"). + SaveX(context.Background()) + if decision2 == nil { + assert.Error(t, errors.Errorf("Failed to create sample decision")) + } + + tests := []struct { + name string + env map[string]interface{} + code string + result string + err string + }{ + { + name: "GetDecisionsSinceCount() test: existing IP count since more than 1 day", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + "GetDecisionsSinceCount": GetDecisionsSinceCount, + "sprintf": fmt.Sprintf, + }, + code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '25h'))", + result: "2", + err: "", + }, + { + name: "GetDecisionsSinceCount() test: existing IP count since more than 1 hour", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &existingIP, + }, + Decisions: []*models.Decision{ + { + Value: &existingIP, + }, + }, + }, + "GetDecisionsSinceCount": GetDecisionsSinceCount, + "sprintf": fmt.Sprintf, + }, + code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", + result: "1", + err: "", + }, + { + name: "GetDecisionsSinceCount() test: unknown IP count", + env: map[string]interface{}{ + "Alert": &models.Alert{ + Source: &models.Source{ + Value: &unknownIP, + }, + Decisions: []*models.Decision{ + { + Value: &unknownIP, + }, + }, + }, + "GetDecisionsSinceCount": GetDecisionsSinceCount, + "sprintf": fmt.Sprintf, + }, + code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", + result: "0", + err: "", + }, + } + + for _, test := range tests { + program, err := expr.Compile(test.code, expr.Env(GetExprEnv(test.env))) + require.NoError(t, err) + output, err := expr.Run(program, GetExprEnv(test.env)) + require.NoError(t, err) + require.Equal(t, test.result, output) + log.Printf("test '%s' : OK", test.name) + } +} diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 30825478e..dd8b9ea87 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -8,7 +8,7 @@ import ( ) func TestJsonExtract(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -54,7 +54,7 @@ func TestJsonExtract(t *testing.T) { } func TestJsonExtractUnescape(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -94,7 +94,7 @@ func TestJsonExtractUnescape(t *testing.T) { } func TestJsonExtractSlice(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -144,7 +144,7 @@ func TestJsonExtractSlice(t *testing.T) { } func TestJsonExtractObject(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } diff --git a/pkg/exprhelpers/xml_test.go b/pkg/exprhelpers/xml_test.go index ca1f54b37..a6fdae3f9 100644 --- a/pkg/exprhelpers/xml_test.go +++ b/pkg/exprhelpers/xml_test.go @@ -8,7 +8,7 @@ import ( ) func TestXMLGetAttributeValue(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } @@ -67,7 +67,7 @@ func TestXMLGetAttributeValue(t *testing.T) { } func TestXMLGetNodeValue(t *testing.T) { - if err := Init(); err != nil { + if err := Init(nil); err != nil { log.Fatalf(err.Error()) } diff --git a/pkg/leakybucket/buckets_test.go b/pkg/leakybucket/buckets_test.go index fb4934c6f..23e5b1c79 100644 --- a/pkg/leakybucket/buckets_test.go +++ b/pkg/leakybucket/buckets_test.go @@ -33,7 +33,7 @@ func TestBucket(t *testing.T) { envSetting = os.Getenv("TEST_ONLY") tomb *tomb.Tomb = &tomb.Tomb{} ) - err := exprhelpers.Init() + err := exprhelpers.Init(nil) if err != nil { log.Fatalf("exprhelpers init failed: %s", err) } diff --git a/pkg/models/helpers.go b/pkg/models/helpers.go index d476e1fc3..0724d89ae 100644 --- a/pkg/models/helpers.go +++ b/pkg/models/helpers.go @@ -11,6 +11,13 @@ func (a *Alert) GetScope() string { return *a.Source.Scope } +func (a *Alert) GetValue() string { + if a.Source.Value == nil { + return "" + } + return *a.Source.Value +} + func (a *Alert) GetScenario() string { if a.Scenario == nil { return "" diff --git a/pkg/parser/node.go b/pkg/parser/node.go index 1b0282bde..96107f19f 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -58,7 +58,7 @@ type Node struct { //Statics can be present in any type of node and is executed last Statics []types.ExtraField `yaml:"statics,omitempty"` //Whitelists - Whitelist types.Whitelist `yaml:"whitelist,omitempty"` + Whitelist Whitelist `yaml:"whitelist,omitempty"` Data []*types.DataSource `yaml:"data,omitempty"` } @@ -531,7 +531,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { valid = true } for _, filter := range n.Whitelist.Exprs { - expression := &types.ExprWhitelist{} + expression := &ExprWhitelist{} expression.Filter, err = expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) if err != nil { n.Logger.Fatalf("Unable to compile whitelist expression '%s' : %v.", filter, err) diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index ba87438b7..6a6e8dbf0 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -146,7 +146,7 @@ func prepTests() (*UnixParserCtx, EnricherCtx, error) { ectx EnricherCtx ) - err = exprhelpers.Init() + err = exprhelpers.Init(nil) if err != nil { log.Fatalf("exprhelpers init failed: %s", err) } diff --git a/pkg/types/whitelist.go b/pkg/parser/whitelist.go similarity index 97% rename from pkg/types/whitelist.go rename to pkg/parser/whitelist.go index bc8723222..c56ad40a7 100644 --- a/pkg/types/whitelist.go +++ b/pkg/parser/whitelist.go @@ -1,4 +1,4 @@ -package types +package parser import ( "net"