diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index 82f56732d..2e1fc7092 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -7,11 +7,12 @@ import ( "text/template" "github.com/antonmedv/expr" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v2" - log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) func showConfigKey(key string) error { @@ -19,7 +20,10 @@ func showConfigKey(key string) error { Config *csconfig.Config } - program, err := expr.Compile(key, expr.Env(Env{})) + opts := []expr.Option{} + opts = append(opts, exprhelpers.GetExprOptions(map[string]interface{}{})...) + opts = append(opts, expr.Env(Env{})) + program, err := expr.Compile(key, opts...) if err != nil { return err } @@ -50,7 +54,6 @@ func showConfigKey(key string) error { return nil } - var configShowTemplate = `Global: {{- if .ConfigPaths }} @@ -172,7 +175,6 @@ Central API: {{- end }} ` - func runConfigShow(cmd *cobra.Command, args []string) error { flags := cmd.Flags() @@ -218,7 +220,6 @@ func runConfigShow(cmd *cobra.Command, args []string) error { return nil } - func NewConfigShowCmd() *cobra.Command { cmdConfigShow := &cobra.Command{ Use: "show", diff --git a/go.mod b/go.mod index 8b94ef948..735090136 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/AlecAivazis/survey/v2 v2.2.7 github.com/Microsoft/go-winio v0.5.2 // indirect github.com/alexliesenfeld/health v0.5.1 - github.com/antonmedv/expr v1.12.2 + github.com/antonmedv/expr v1.12.5 github.com/appleboy/gin-jwt/v2 v2.8.0 github.com/aws/aws-sdk-go v1.42.25 github.com/buger/jsonparser v1.1.1 diff --git a/go.sum b/go.sum index 3b2bc3b13..d78bfb3fd 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,10 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antonmedv/expr v1.12.2 h1:nlRcu4uHI6oSKCf6GHJTcT7hIf7dFAjgfvG0MWb7Cu0= github.com/antonmedv/expr v1.12.2/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU= +github.com/antonmedv/expr v1.12.4 h1:YRkeF7r0cejMS47bDYe3Jyes7L9t1AhpunC+Duq+R9k= +github.com/antonmedv/expr v1.12.4/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU= +github.com/antonmedv/expr v1.12.5 h1:Fq4okale9swwL3OeLLs9WD9H6GbgBLJyN/NUHRv+n0E= +github.com/antonmedv/expr v1.12.5/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU= github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw= github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/appleboy/gin-jwt/v2 v2.8.0 h1:Glo7cb9eBR+hj8Y7WzgfkOlqCaNLjP+RV4dNO3fpdps= diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_test.go index a24a63efc..54fddc3d8 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_test.go @@ -214,8 +214,8 @@ event_ids: } t.Fatalf("timeout") case e := <-c: - - linesRead = append(linesRead, exprhelpers.XMLGetNodeValue(e.Line.Raw, "/Event/EventData[1]/Data")) + line, _ := exprhelpers.XMLGetNodeValue(e.Line.Raw, "/Event/EventData[1]/Data") + linesRead = append(linesRead, line.(string)) if len(linesRead) == len(lines) { break READLOOP } diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go index 9cf6a586c..16a6f3bfe 100644 --- a/pkg/alertcontext/alertcontext.go +++ b/pkg/alertcontext/alertcontext.go @@ -30,7 +30,7 @@ type Context struct { func ValidateContextExpr(key string, expressions []string) error { for _, expression := range expressions { - _, err := expr.Compile(expression, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("compilation of '%s' failed: %v", expression, err) } @@ -63,7 +63,7 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { for key, values := range contextToSend { alertContext.ContextToSendCompiled[key] = make([]*vm.Program, 0) for _, value := range values { - valueCompiled, err := expr.Compile(value, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("compilation of '%s' context value failed: %v", value, err) } @@ -117,7 +117,7 @@ func EventToContext(events []types.Event) (models.Meta, []error) { } for _, value := range values { var val string - output, err := expr.Run(value, exprhelpers.GetExprEnv(map[string]interface{}{"evt": evt})) + output, err := expr.Run(value, map[string]interface{}{"evt": evt}) if err != nil { errors = append(errors, fmt.Errorf("failed to get value for %s : %v", key, err)) continue diff --git a/pkg/csprofiles/csprofiles.go b/pkg/csprofiles/csprofiles.go index b0fe2c79f..29e6cdf36 100644 --- a/pkg/csprofiles/csprofiles.go +++ b/pkg/csprofiles/csprofiles.go @@ -53,12 +53,13 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { return []*Runtime{}, errors.Wrapf(err, "invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure) } for fIdx, filter := range profile.Filters { - if runtimeFilter, err = expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + + if runtimeFilter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { return []*Runtime{}, errors.Wrapf(err, "error compiling filter of '%s'", profile.Name) } runtime.RuntimeFilters[fIdx] = runtimeFilter if profile.Debug != nil && *profile.Debug { - if debugFilter, err = exprhelpers.NewDebugger(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + if debugFilter, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { log.Debugf("Error compiling debug filter of %s : %s", profile.Name, err) // Don't fail if we can't compile the filter - for now // return errors.Wrapf(err, "Error compiling debug filter of %s", profile.Name) @@ -69,13 +70,13 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { } if profile.DurationExpr != "" { - if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { return []*Runtime{}, errors.Wrapf(err, "error compiling duration_expr of %s", profile.Name) } runtime.RuntimeDurationExpr = runtimeDurationExpr if profile.Debug != nil && *profile.Debug { - if debugDurationExpr, err = exprhelpers.NewDebugger(profile.DurationExpr, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"Alert": &models.Alert{}}))); err != nil { + if debugDurationExpr, err = exprhelpers.NewDebugger(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { log.Debugf("Error compiling debug duration_expr of %s : %s", profile.Name, err) } runtime.DebugDurationExpr = debugDurationExpr @@ -120,7 +121,7 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod /*some fields are populated from the reference object : duration, scope, type*/ decision.Duration = new(string) if Profile.Cfg.DurationExpr != "" && Profile.RuntimeDurationExpr != nil { - duration, err := expr.Run(Profile.RuntimeDurationExpr, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) + duration, err := expr.Run(Profile.RuntimeDurationExpr, map[string]interface{}{"Alert": Alert}) if err != nil { Profile.Logger.Warningf("Failed to run duration_expr : %v", err) *decision.Duration = *refDecision.Duration @@ -164,7 +165,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision matched := false for eIdx, expression := range Profile.RuntimeFilters { - output, err := expr.Run(expression, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) + output, err := expr.Run(expression, map[string]interface{}{"Alert": Alert}) if err != nil { Profile.Logger.Warningf("failed to run whitelist expr : %v", err) return nil, matched, errors.Wrapf(err, "while running expression %s", Profile.Cfg.Filters[eIdx]) @@ -172,7 +173,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision switch out := output.(type) { case bool: if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { - Profile.DebugFilters[eIdx].Run(Profile.Logger, out, exprhelpers.GetExprEnv(map[string]interface{}{"Alert": Alert})) + Profile.DebugFilters[eIdx].Run(Profile.Logger, out, map[string]interface{}{"Alert": Alert}) } if out { matched = true diff --git a/pkg/csprofiles/csprofiles_test.go b/pkg/csprofiles/csprofiles_test.go index d9a757092..81b21e1fb 100644 --- a/pkg/csprofiles/csprofiles_test.go +++ b/pkg/csprofiles/csprofiles_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" "gotest.tools/v3/assert" ) @@ -104,6 +105,9 @@ func TestEvaluateProfile(t *testing.T) { profileCfg *csconfig.ProfileCfg Alert *models.Alert } + + exprhelpers.Init(nil) + tests := []struct { name string args args diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index 5596f8278..6440295c8 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -73,7 +73,9 @@ func CrowdsecCTIInitCache(size int, ttl time.Duration) { CacheExpiration = ttl } -func CrowdsecCTI(ip string) (*cticlient.SmokeItem, error) { +// func CrowdsecCTI(ip string) (*cticlient.SmokeItem, error) { +func CrowdsecCTI(params ...any) (any, error) { + ip := params[0].(string) if !CTIApiEnabled { ctiClient.Logger.Warningf("Crowdsec CTI API is disabled, please check your configuration") return &cticlient.SmokeItem{}, cticlient.ErrDisabled diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index 7e9bc0c31..41afcd6ee 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -157,13 +157,16 @@ func TestCache(t *testing.T) { })) item, err := CrowdsecCTI("1.2.3.4") - assert.Equal(t, "1.2.3.4", item.Ip) + ctiResp := item.(*cticlient.SmokeItem) + assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, CTIApiEnabled, true) assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, err, nil) item, err = CrowdsecCTI("1.2.3.4") - assert.Equal(t, "1.2.3.4", item.Ip) + ctiResp = item.(*cticlient.SmokeItem) + + assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, CTIApiEnabled, true) assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, err, nil) @@ -173,7 +176,9 @@ func TestCache(t *testing.T) { assert.Equal(t, CTICache.Len(true), 0) item, err = CrowdsecCTI("1.2.3.4") - assert.Equal(t, "1.2.3.4", item.Ip) + ctiResp = item.(*cticlient.SmokeItem) + + assert.Equal(t, "1.2.3.4", ctiResp.Ip) assert.Equal(t, CTIApiEnabled, true) assert.Equal(t, CTICache.Len(true), 1) assert.Equal(t, err, nil) diff --git a/pkg/exprhelpers/expr_lib.go b/pkg/exprhelpers/expr_lib.go new file mode 100644 index 000000000..13e58b877 --- /dev/null +++ b/pkg/exprhelpers/expr_lib.go @@ -0,0 +1,384 @@ +package exprhelpers + +import ( + "time" + + "github.com/crowdsecurity/crowdsec/pkg/cticlient" +) + +type exprCustomFunc struct { + name string + function func(params ...any) (any, error) + signature []interface{} +} + +var exprFuncs = []exprCustomFunc{ + { + name: "CrowdsecCTI", + function: CrowdsecCTI, + signature: []interface{}{ + new(func(string) (*cticlient.SmokeItem, error)), + }, + }, + { + name: "Distance", + function: Distance, + signature: []interface{}{ + new(func(string, string, string, string) (float64, error)), + }, + }, + { + name: "GetFromStash", + function: GetFromStash, + signature: []interface{}{ + new(func(string, string) (string, error)), + }, + }, + { + name: "Atof", + function: Atof, + signature: []interface{}{ + new(func(string) float64), + }, + }, + { + name: "JsonExtract", + function: JsonExtract, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "JsonExtractUnescape", + function: JsonExtractUnescape, + signature: []interface{}{ + new(func(string, ...string) string), + }, + }, + { + name: "JsonExtractLib", + function: JsonExtractLib, + signature: []interface{}{ + new(func(string, ...string) string), + }, + }, + { + name: "JsonExtractSlice", + function: JsonExtractSlice, + signature: []interface{}{ + new(func(string, string) []interface{}), + }, + }, + { + name: "JsonExtractObject", + function: JsonExtractObject, + signature: []interface{}{ + new(func(string, string) map[string]interface{}), + }, + }, + { + name: "ToJsonString", + function: ToJson, + signature: []interface{}{ + new(func(interface{}) string), + }, + }, + { + name: "File", + function: File, + signature: []interface{}{ + new(func(string) []string), + }, + }, + { + name: "RegexpInFile", + function: RegexpInFile, + signature: []interface{}{ + new(func(string, string) bool), + }, + }, + { + name: "Upper", + function: Upper, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "Lower", + function: Lower, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "IpInRange", + function: IpInRange, + signature: []interface{}{ + new(func(string, string) bool), + }, + }, + { + name: "TimeNow", + function: TimeNow, + signature: []interface{}{ + new(func() string), + }, + }, + { + name: "ParseUri", + function: ParseUri, + signature: []interface{}{ + new(func(string) map[string][]string), + }, + }, + { + name: "PathUnescape", + function: PathUnescape, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "QueryUnescape", + function: QueryUnescape, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "PathEscape", + function: PathEscape, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "QueryEscape", + function: QueryEscape, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "XMLGetAttributeValue", + function: XMLGetAttributeValue, + signature: []interface{}{ + new(func(string, string, string) string), + }, + }, + { + name: "XMLGetNodeValue", + function: XMLGetNodeValue, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "IpToRange", + function: IpToRange, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "IsIPV6", + function: IsIPV6, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "IsIPV4", + function: IsIPV4, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "IsIP", + function: IsIP, + signature: []interface{}{ + new(func(string) bool), + }, + }, + { + name: "LookupHost", + function: LookupHost, + signature: []interface{}{ + new(func(string) []string), + }, + }, + { + name: "GetDecisionsCount", + function: GetDecisionsCount, + signature: []interface{}{ + new(func(string) int), + }, + }, + { + name: "GetDecisionsSinceCount", + function: GetDecisionsSinceCount, + signature: []interface{}{ + new(func(string, string) int), + }, + }, + { + name: "Sprintf", + function: Sprintf, + signature: []interface{}{ + new(func(string, ...interface{}) string), + }, + }, + { + name: "ParseUnix", + function: ParseUnix, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "SetInStash", //FIXME: signature will probably blow everything up + function: SetInStash, + signature: []interface{}{ + new(func(string, string, string, *time.Duration) error), + }, + }, + { + name: "Fields", + function: Fields, + signature: []interface{}{ + new(func(string) []string), + }, + }, + { + name: "Index", + function: Index, + signature: []interface{}{ + new(func(string, string) int), + }, + }, + { + name: "IndexAny", + function: IndexAny, + signature: []interface{}{ + new(func(string, string) int), + }, + }, + { + name: "Join", + function: Join, + signature: []interface{}{ + new(func([]string, string) string), + }, + }, + { + name: "Split", + function: Split, + signature: []interface{}{ + new(func(string, string) []string), + }, + }, + { + name: "SplitAfter", + function: SplitAfter, + signature: []interface{}{ + new(func(string, string) []string), + }, + }, + { + name: "SplitAfterN", + function: SplitAfterN, + signature: []interface{}{ + new(func(string, string, int) []string), + }, + }, + { + name: "SplitN", + function: SplitN, + signature: []interface{}{ + new(func(string, string, int) []string), + }, + }, + { + name: "Replace", + function: Replace, + signature: []interface{}{ + new(func(string, string, string, int) string), + }, + }, + { + name: "ReplaceAll", + function: ReplaceAll, + signature: []interface{}{ + new(func(string, string, string) string), + }, + }, + { + name: "Trim", + function: Trim, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "TrimLeft", + function: TrimLeft, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "TrimRight", + function: TrimRight, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "TrimSpace", + function: TrimSpace, + signature: []interface{}{ + new(func(string) string), + }, + }, + { + name: "TrimPrefix", + function: TrimPrefix, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "TrimSuffix", + function: TrimSuffix, + signature: []interface{}{ + new(func(string, string) string), + }, + }, + { + name: "Get", + function: Get, + signature: []interface{}{ + new(func([]string, int) string), + }, + }, + { + name: "ToString", + function: ToString, + signature: []interface{}{ + new(func(interface{}) string), + }, + }, + { + name: "Match", + function: Match, + signature: []interface{}{ + new(func(string, string) bool), + }, + }, +} + +//go 1.20 "CutPrefix": strings.CutPrefix, +//go 1.20 "CutSuffix": strings.CutSuffix, +//"Cut": strings.Cut, -> returns more than 2 values, not supported by expr diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 49f1ba1c9..7fb471680 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -100,17 +100,17 @@ func TestVisitor(t *testing.T) { }) for _, test := range tests { - compiledFilter, err := expr.Compile(test.filter, expr.Env(GetExprEnv(test.env))) + compiledFilter, err := expr.Compile(test.filter, GetExprOptions(test.env)...) if err != nil && test.err == nil { log.Fatalf("compile: %s", err) } - debugFilter, err := NewDebugger(test.filter, expr.Env(GetExprEnv(test.env))) + debugFilter, err := NewDebugger(test.filter, GetExprOptions(test.env)...) if err != nil && test.err == nil { log.Fatalf("debug: %s", err) } if compiledFilter != nil { - result, err := expr.Run(compiledFilter, GetExprEnv(test.env)) + result, err := expr.Run(compiledFilter, test.env) if err != nil && test.err == nil { log.Fatalf("run : %s", err) } @@ -120,37 +120,49 @@ func TestVisitor(t *testing.T) { } if debugFilter != nil { - debugFilter.Run(clog, test.result, GetExprEnv(test.env)) + debugFilter.Run(clog, test.result, test.env) } } } func TestMatch(t *testing.T) { + err := Init(nil) + require.NoError(t, err) tests := []struct { glob string val string ret bool + expr string }{ - {"foo", "foo", true}, - {"foo", "bar", false}, - {"foo*", "foo", true}, - {"foo*", "foobar", true}, - {"foo*", "barfoo", false}, - {"foo*", "bar", false}, - {"*foo", "foo", true}, - {"*foo", "barfoo", true}, - {"foo*r", "foobar", true}, - {"foo*r", "foobazr", true}, - {"foo?ar", "foobar", true}, - {"foo?ar", "foobazr", false}, - {"foo?ar", "foobaz", false}, - {"*foo?ar?", "foobar", false}, - {"*foo?ar?", "foobare", true}, - {"*foo?ar?", "rafoobar", false}, - {"*foo?ar?", "rafoobare", true}, + {"foo", "foo", true, `Match(pattern, name)`}, + {"foo", "bar", false, `Match(pattern, name)`}, + {"foo*", "foo", true, `Match(pattern, name)`}, + {"foo*", "foobar", true, `Match(pattern, name)`}, + {"foo*", "barfoo", false, `Match(pattern, name)`}, + {"foo*", "bar", false, `Match(pattern, name)`}, + {"*foo", "foo", true, `Match(pattern, name)`}, + {"*foo", "barfoo", true, `Match(pattern, name)`}, + {"foo*r", "foobar", true, `Match(pattern, name)`}, + {"foo*r", "foobazr", true, `Match(pattern, name)`}, + {"foo?ar", "foobar", true, `Match(pattern, name)`}, + {"foo?ar", "foobazr", false, `Match(pattern, name)`}, + {"foo?ar", "foobaz", false, `Match(pattern, name)`}, + {"*foo?ar?", "foobar", false, `Match(pattern, name)`}, + {"*foo?ar?", "foobare", true, `Match(pattern, name)`}, + {"*foo?ar?", "rafoobar", false, `Match(pattern, name)`}, + {"*foo?ar?", "rafoobare", true, `Match(pattern, name)`}, } for _, test := range tests { - ret := Match(test.glob, test.val) + env := map[string]interface{}{ + "pattern": test.glob, + "name": test.val, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + if err != nil { + t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err) + } + ret, err := expr.Run(vm, env) + assert.NoError(t, err) if isOk := assert.Equal(t, test.ret, ret); !isOk { t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret) } @@ -158,19 +170,45 @@ func TestMatch(t *testing.T) { } func TestDistanceHelper(t *testing.T) { + err := Init(nil) + require.NoError(t, err) - //one set of coord is empty - ret, err := Distance("0.0", "0.0", "12.1", "12.1") - assert.NoError(t, err) - assert.Equal(t, 0.0, ret) - //those aren't even coords - ret, err = Distance("lol", "42.1", "12.1", "12.1") - assert.NotNil(t, err) - assert.Equal(t, 0.0, ret) - //real ones - ret, err = Distance("51.45", "1.15", "41.54", "12.27") - assert.NoError(t, err) - assert.Equal(t, 1389.1793118293067, ret) + tests := []struct { + lat1 string + lon1 string + lat2 string + lon2 string + dist float64 + valid bool + expr string + name string + }{ + {"51.45", "1.15", "41.54", "12.27", 1389.1793118293067, true, `Distance(lat1, lon1, lat2, lon2)`, "valid"}, + {"lol", "1.15", "41.54", "12.27", 0.0, false, `Distance(lat1, lon1, lat2, lon2)`, "invalid lat1"}, + {"0.0", "0.0", "12.1", "12.1", 0.0, true, `Distance(lat1, lon1, lat2, lon2)`, "empty coord"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + env := map[string]interface{}{ + "lat1": test.lat1, + "lon1": test.lon1, + "lat2": test.lat2, + "lon2": test.lon2, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + if err != nil { + t.Fatalf("pattern:%s val:%s NOK %s", test.lat1, test.lon1, err) + } + ret, err := expr.Run(vm, env) + if test.valid { + assert.NoError(t, err) + assert.Equal(t, test.dist, ret) + } else { + assert.NotNil(t, err) + } + }) + } } func TestRegexpCacheBehavior(t *testing.T) { @@ -185,12 +223,12 @@ func TestRegexpCacheBehavior(t *testing.T) { err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: types.IntPtr(1)}) require.NoError(t, err) - ret := RegexpInFile("crowdsec", filename) - assert.False(t, ret) + ret, _ := RegexpInFile("crowdsec", filename) + assert.False(t, ret.(bool)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(false)) - ret = RegexpInFile("Crowdsec", filename) - assert.True(t, ret) + ret, _ = RegexpInFile("Crowdsec", filename) + assert.True(t, ret.(bool)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(false)) //cache with TTL @@ -198,8 +236,8 @@ func TestRegexpCacheBehavior(t *testing.T) { err = RegexpCacheInit(filename, types.DataSource{Type: "regex", Size: types.IntPtr(2), TTL: &ttl}) require.NoError(t, err) - ret = RegexpInFile("crowdsec", filename) - assert.False(t, ret) + ret, _ = RegexpInFile("crowdsec", filename) + assert.False(t, ret.(bool)) assert.Equal(t, 1, dataFileRegexCache[filename].Len(true)) time.Sleep(1 * time.Second) @@ -249,11 +287,11 @@ func TestRegexpInFile(t *testing.T) { } for _, test := range tests { - compiledFilter, err := expr.Compile(test.filter, expr.Env(GetExprEnv(map[string]interface{}{}))) + compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) if err != nil { log.Fatal(err) } - result, err := expr.Run(compiledFilter, GetExprEnv(map[string]interface{}{})) + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } @@ -374,11 +412,11 @@ func TestFile(t *testing.T) { } for _, test := range tests { - compiledFilter, err := expr.Compile(test.filter, expr.Env(GetExprEnv(map[string]interface{}{}))) + compiledFilter, err := expr.Compile(test.filter, GetExprOptions(map[string]interface{}{})...) if err != nil { log.Fatal(err) } - result, err := expr.Run(compiledFilter, GetExprEnv(map[string]interface{}{})) + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } @@ -391,6 +429,8 @@ func TestFile(t *testing.T) { } func TestIpInRange(t *testing.T) { + err := Init(nil) + assert.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -401,9 +441,8 @@ func TestIpInRange(t *testing.T) { { name: "IpInRange() test: basic test", env: map[string]interface{}{ - "ip": "192.168.0.1", - "ipRange": "192.168.0.0/24", - "IpInRange": IpInRange, + "ip": "192.168.0.1", + "ipRange": "192.168.0.0/24", }, code: "IpInRange(ip, ipRange)", result: true, @@ -412,9 +451,8 @@ func TestIpInRange(t *testing.T) { { name: "IpInRange() test: malformed IP", env: map[string]interface{}{ - "ip": "192.168.0", - "ipRange": "192.168.0.0/24", - "IpInRange": IpInRange, + "ip": "192.168.0", + "ipRange": "192.168.0.0/24", }, code: "IpInRange(ip, ipRange)", result: false, @@ -423,9 +461,8 @@ func TestIpInRange(t *testing.T) { { name: "IpInRange() test: malformed IP range", env: map[string]interface{}{ - "ip": "192.168.0.0/255", - "ipRange": "192.168.0.0/24", - "IpInRange": IpInRange, + "ip": "192.168.0.0/255", + "ipRange": "192.168.0.0/24", }, code: "IpInRange(ip, ipRange)", result: false, @@ -434,7 +471,7 @@ func TestIpInRange(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -445,6 +482,8 @@ func TestIpInRange(t *testing.T) { } func TestIpToRange(t *testing.T) { + err := Init(nil) + assert.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -455,9 +494,8 @@ func TestIpToRange(t *testing.T) { { name: "IpToRange() test: IPv4", env: map[string]interface{}{ - "ip": "192.168.1.1", - "netmask": "16", - "IpToRange": IpToRange, + "ip": "192.168.1.1", + "netmask": "16", }, code: "IpToRange(ip, netmask)", result: "192.168.0.0/16", @@ -466,9 +504,8 @@ func TestIpToRange(t *testing.T) { { name: "IpToRange() test: IPv6", env: map[string]interface{}{ - "ip": "2001:db8::1", - "netmask": "/64", - "IpToRange": IpToRange, + "ip": "2001:db8::1", + "netmask": "/64", }, code: "IpToRange(ip, netmask)", result: "2001:db8::/64", @@ -477,9 +514,8 @@ func TestIpToRange(t *testing.T) { { name: "IpToRange() test: malformed netmask", env: map[string]interface{}{ - "ip": "192.168.0.1", - "netmask": "test", - "IpToRange": IpToRange, + "ip": "192.168.0.1", + "netmask": "test", }, code: "IpToRange(ip, netmask)", result: "", @@ -488,9 +524,8 @@ func TestIpToRange(t *testing.T) { { name: "IpToRange() test: malformed IP", env: map[string]interface{}{ - "ip": "a.b.c.d", - "netmask": "24", - "IpToRange": IpToRange, + "ip": "a.b.c.d", + "netmask": "24", }, code: "IpToRange(ip, netmask)", result: "", @@ -499,9 +534,8 @@ func TestIpToRange(t *testing.T) { { name: "IpToRange() test: too high netmask", env: map[string]interface{}{ - "ip": "192.168.1.1", - "netmask": "35", - "IpToRange": IpToRange, + "ip": "192.168.1.1", + "netmask": "35", }, code: "IpToRange(ip, netmask)", result: "", @@ -510,7 +544,7 @@ func TestIpToRange(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -521,39 +555,72 @@ func TestIpToRange(t *testing.T) { } func TestAtof(t *testing.T) { - testFloat := "1.5" - expectedFloat := 1.5 - if Atof(testFloat) != expectedFloat { - t.Fatalf("Atof should return 1.5 as a float") + err := Init(nil) + assert.NoError(t, err) + + tests := []struct { + name string + env map[string]interface{} + code string + result float64 + }{ + { + name: "Atof() test: basic test", + env: map[string]interface{}{ + "testFloat": "1.5", + }, + code: "Atof(testFloat)", + result: 1.5, + }, + { + name: "Atof() test: bad float", + env: map[string]interface{}{ + "testFloat": "1aaa.5", + }, + code: "Atof(testFloat)", + result: 0.0, + }, } - log.Printf("test 'Atof()' : OK") - - //bad float - testFloat = "1aaa.5" - expectedFloat = 0.0 - - if Atof(testFloat) != expectedFloat { - t.Fatalf("Atof should return a negative value (error) as a float got") + for _, test := range tests { + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) + require.NoError(t, err) + output, err := expr.Run(program, test.env) + require.NoError(t, err) + require.Equal(t, test.result, output) } - - log.Printf("test 'Atof()' : OK") } func TestUpper(t *testing.T) { testStr := "test" expectedStr := "TEST" - if Upper(testStr) != expectedStr { - t.Fatalf("Upper() should return test in upper case") + env := map[string]interface{}{ + "testStr": testStr, } - log.Printf("test 'Upper()' : OK") + err := Init(nil) + assert.NoError(t, err) + vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...) + assert.NoError(t, err) + + out, err := expr.Run(vm, env) + + assert.NoError(t, err) + v, ok := out.(string) + if !ok { + t.Fatalf("Upper() should return a string") + } + + if v != expectedStr { + t.Fatalf("Upper() should return test in upper case") + } } func TestTimeNow(t *testing.T) { - ti, err := time.Parse(time.RFC3339, TimeNow()) + now, _ := TimeNow() + ti, err := time.Parse(time.RFC3339, now.(string)) if err != nil { t.Fatalf("Error parsing the return value of TimeNow: %s", err) } @@ -625,7 +692,7 @@ func TestParseUri(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -665,7 +732,7 @@ func TestQueryEscape(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -705,7 +772,7 @@ func TestPathEscape(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -745,7 +812,7 @@ func TestPathUnescape(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -785,7 +852,7 @@ func TestQueryUnescape(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -825,7 +892,7 @@ func TestLower(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(test.env)) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) output, err := expr.Run(program, test.env) require.NoError(t, err) @@ -865,6 +932,9 @@ func TestGetDecisionsCount(t *testing.T) { assert.Error(t, errors.Errorf("Failed to create sample decision")) } + err = Init(dbClient) + assert.NoError(t, err) + tests := []struct { name string env map[string]interface{} @@ -885,10 +955,8 @@ func TestGetDecisionsCount(t *testing.T) { }, }, }, - "GetDecisionsCount": GetDecisionsCount, - "sprintf": fmt.Sprintf, }, - code: "sprintf('%d', GetDecisionsCount(Alert.GetValue()))", + code: "Sprintf('%d', GetDecisionsCount(Alert.GetValue()))", result: "1", err: "", }, @@ -905,19 +973,17 @@ func TestGetDecisionsCount(t *testing.T) { }, }, }, - "GetDecisionsCount": GetDecisionsCount, - "sprintf": fmt.Sprintf, }, - code: "sprintf('%d', GetDecisionsCount(Alert.GetValue()))", + code: "Sprintf('%d', GetDecisionsCount(Alert.GetValue()))", result: "0", err: "", }, } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(GetExprEnv(test.env))) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) - output, err := expr.Run(program, GetExprEnv(test.env)) + output, err := expr.Run(program, test.env) require.NoError(t, err) require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) @@ -970,6 +1036,9 @@ func TestGetDecisionsSinceCount(t *testing.T) { assert.Error(t, errors.Errorf("Failed to create sample decision")) } + err = Init(dbClient) + assert.NoError(t, err) + tests := []struct { name string env map[string]interface{} @@ -990,10 +1059,8 @@ func TestGetDecisionsSinceCount(t *testing.T) { }, }, }, - "GetDecisionsSinceCount": GetDecisionsSinceCount, - "sprintf": fmt.Sprintf, }, - code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '25h'))", + code: "Sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '25h'))", result: "2", err: "", }, @@ -1010,10 +1077,8 @@ func TestGetDecisionsSinceCount(t *testing.T) { }, }, }, - "GetDecisionsSinceCount": GetDecisionsSinceCount, - "sprintf": fmt.Sprintf, }, - code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", + code: "Sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", result: "1", err: "", }, @@ -1030,19 +1095,17 @@ func TestGetDecisionsSinceCount(t *testing.T) { }, }, }, - "GetDecisionsSinceCount": GetDecisionsSinceCount, - "sprintf": fmt.Sprintf, }, - code: "sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", + code: "Sprintf('%d', GetDecisionsSinceCount(Alert.GetValue(), '1h'))", result: "0", err: "", }, } for _, test := range tests { - program, err := expr.Compile(test.code, expr.Env(GetExprEnv(test.env))) + program, err := expr.Compile(test.code, GetExprOptions(test.env)...) require.NoError(t, err) - output, err := expr.Run(program, GetExprEnv(test.env)) + output, err := expr.Run(program, test.env) require.NoError(t, err) require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) @@ -1088,113 +1151,156 @@ func TestParseUnixTime(t *testing.T) { if tc.expectedErr != "" { return } - require.WithinDuration(t, tc.expected, output, time.Second) + require.WithinDuration(t, tc.expected, output.(time.Time), time.Second) }) } } func TestIsIp(t *testing.T) { + if err := Init(nil); err != nil { + log.Fatal(err) + } tests := []struct { - name string - method func(string) bool - value string - expected bool + name string + expr string + value string + expected bool + expectedBuildErr bool }{ { name: "IsIPV4() test: valid IPv4", - method: IsIPV4, + expr: `IsIPV4(value)`, value: "1.2.3.4", expected: true, }, { name: "IsIPV6() test: valid IPv6", - method: IsIPV6, + expr: `IsIPV6(value)`, value: "1.2.3.4", expected: false, }, { name: "IsIPV6() test: valid IPv6", - method: IsIPV6, + expr: `IsIPV6(value)`, value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", expected: true, }, { name: "IsIPV4() test: valid IPv6", - method: IsIPV4, + expr: `IsIPV4(value)`, value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", expected: false, }, { name: "IsIP() test: invalid IP", - method: IsIP, + expr: `IsIP(value)`, value: "foo.bar", expected: false, }, { name: "IsIP() test: valid IPv4", - method: IsIP, + expr: `IsIP(value)`, value: "1.2.3.4", expected: true, }, { name: "IsIP() test: valid IPv6", - method: IsIP, + expr: `IsIP(value)`, value: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", expected: true, }, { name: "IsIPV4() test: invalid IPv4", - method: IsIPV4, + expr: `IsIPV4(value)`, value: "foo.bar", expected: false, }, { name: "IsIPV6() test: invalid IPv6", - method: IsIPV6, + expr: `IsIPV6(value)`, value: "foo.bar", expected: false, }, + { + name: "IsIPV4() test: invalid type", + expr: `IsIPV4(42)`, + value: "", + expected: false, + expectedBuildErr: true, + }, + { + name: "IsIP() test: invalid type", + expr: `IsIP(42)`, + value: "", + expected: false, + expectedBuildErr: true, + }, + { + name: "IsIPV6() test: invalid type", + expr: `IsIPV6(42)`, + value: "", + expected: false, + expectedBuildErr: true, + }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - output := tc.method(tc.value) - require.Equal(t, tc.expected, output) + vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) + if tc.expectedBuildErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) + assert.NoError(t, err) + assert.IsType(t, tc.expected, output) + assert.Equal(t, tc.expected, output.(bool)) }) } } func TestToString(t *testing.T) { + err := Init(nil) + require.NoError(t, err) tests := []struct { name string value interface{} expected string + expr string }{ { name: "ToString() test: valid string", value: "foo", expected: "foo", + expr: `ToString(value)`, }, { name: "ToString() test: valid string", value: interface{}("foo"), expected: "foo", + expr: `ToString(value)`, }, { name: "ToString() test: invalid type", value: 1, expected: "", + expr: `ToString(value)`, }, { name: "ToString() test: invalid type 2", value: interface{}(nil), expected: "", + expr: `ToString(value)`, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - output := ToString(tc.value) + vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) + assert.NoError(t, err) + output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) + assert.NoError(t, err) require.Equal(t, tc.expected, output) }) } diff --git a/pkg/exprhelpers/exprlib.go b/pkg/exprhelpers/helpers.go similarity index 56% rename from pkg/exprhelpers/exprlib.go rename to pkg/exprhelpers/helpers.go index e1d82be4e..0583c9955 100644 --- a/pkg/exprhelpers/exprlib.go +++ b/pkg/exprhelpers/helpers.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/antonmedv/expr" "github.com/bluele/gcache" "github.com/c-robinson/iplib" "github.com/cespare/xxhash/v2" @@ -43,135 +44,29 @@ var RegexpCacheMetrics = prometheus.NewGaugeVec( var dbClient *database.Client -func Get(arr []string, index int) string { - if index >= len(arr) { - return "" - } - return arr[index] -} +var exprFunctionOptions []expr.Option -func Atof(x string) float64 { - log.Debugf("debug atof %s", x) - ret, err := strconv.ParseFloat(x, 64) - if err != nil { - log.Warningf("Atof : can't convert float '%s' : %v", x, err) - } +func GetExprOptions(ctx map[string]interface{}) []expr.Option { + ret := []expr.Option{} + ret = append(ret, exprFunctionOptions...) + ret = append(ret, expr.Env(ctx)) return ret } -func Upper(s string) string { - return strings.ToUpper(s) -} - -func Lower(s string) string { - return strings.ToLower(s) -} - -func GetExprEnv(ctx map[string]interface{}) map[string]interface{} { - var ExprLib = map[string]interface{}{ - "Atof": Atof, - "JsonExtract": JsonExtract, - "JsonExtractUnescape": JsonExtractUnescape, - "JsonExtractLib": JsonExtractLib, - "JsonExtractSlice": JsonExtractSlice, - "JsonExtractObject": JsonExtractObject, - "ToJsonString": ToJson, - "File": File, - "RegexpInFile": RegexpInFile, - "Upper": Upper, - "Lower": Lower, - "IpInRange": IpInRange, - "TimeNow": TimeNow, - "ParseUri": ParseUri, - "PathUnescape": PathUnescape, - "QueryUnescape": QueryUnescape, - "PathEscape": PathEscape, - "QueryEscape": QueryEscape, - "XMLGetAttributeValue": XMLGetAttributeValue, - "XMLGetNodeValue": XMLGetNodeValue, - "IpToRange": IpToRange, - "IsIPV6": IsIPV6, - "IsIPV4": IsIPV4, - "IsIP": IsIP, - "LookupHost": LookupHost, - "GetDecisionsCount": GetDecisionsCount, - "GetDecisionsSinceCount": GetDecisionsSinceCount, - "Sprintf": fmt.Sprintf, - "CrowdsecCTI": CrowdsecCTI, - "ParseUnix": ParseUnix, - "GetFromStash": cache.GetKey, - "SetInStash": cache.SetKey, - //go 1.20 "CutPrefix": strings.CutPrefix, - //go 1.20 "CutSuffix": strings.CutSuffix, - //"Cut": strings.Cut, -> returns more than 2 values, not supported by expr - "Fields": strings.Fields, - "Index": strings.Index, - "IndexAny": strings.IndexAny, - "Join": strings.Join, - "Split": strings.Split, - "SplitAfter": strings.SplitAfter, - "SplitAfterN": strings.SplitAfterN, - "SplitN": strings.SplitN, - "Replace": strings.Replace, - "ReplaceAll": strings.ReplaceAll, - "Trim": strings.Trim, - "TrimLeft": strings.TrimLeft, - "TrimRight": strings.TrimRight, - "TrimSpace": strings.TrimSpace, - "TrimPrefix": strings.TrimPrefix, - "TrimSuffix": strings.TrimSuffix, - "Get": Get, - "String": ToString, - "Distance": Distance, - "Match": Match, - } - for k, v := range ctx { - ExprLib[k] = v - } - return ExprLib -} - -func Distance(lat1 string, long1 string, lat2 string, long2 string) (float64, error) { - lat1f, err := strconv.ParseFloat(lat1, 64) - if err != nil { - log.Warningf("lat1 is not a float : %v", err) - return 0, fmt.Errorf("lat1 is not a float : %v", err) - } - long1f, err := strconv.ParseFloat(long1, 64) - if err != nil { - log.Warningf("long1 is not a float : %v", err) - return 0, fmt.Errorf("long1 is not a float : %v", err) - } - lat2f, err := strconv.ParseFloat(lat2, 64) - if err != nil { - log.Warningf("lat2 is not a float : %v", err) - - return 0, fmt.Errorf("lat2 is not a float : %v", err) - } - long2f, err := strconv.ParseFloat(long2, 64) - if err != nil { - log.Warningf("long2 is not a float : %v", err) - - return 0, fmt.Errorf("long2 is not a float : %v", err) - } - - //either set of coordinates is 0,0, return 0 to avoid FPs - if (lat1f == 0.0 && long1f == 0.0) || (lat2f == 0.0 && long2f == 0.0) { - log.Warningf("one of the coordinates is 0,0, returning 0") - return 0, nil - } - - first := haversine.Coord{Lat: lat1f, Lon: long1f} - second := haversine.Coord{Lat: lat2f, Lon: long2f} - - _, km := haversine.Distance(first, second) - return km, nil -} - func Init(databaseClient *database.Client) error { dataFile = make(map[string][]string) dataFileRegex = make(map[string][]*regexp.Regexp) dbClient = databaseClient + + exprFunctionOptions = []expr.Option{} + for _, function := range exprFuncs { + exprFunctionOptions = append(exprFunctionOptions, + expr.Function(function.name, + function.function, + function.signature..., + )) + } + return nil } @@ -263,43 +158,132 @@ func FileInit(fileFolder string, filename string, fileType string) error { return nil } -func QueryEscape(s string) string { - return url.QueryEscape(s) +//Expr helpers + +// func Get(arr []string, index int) string { +func Get(params ...any) (any, error) { + arr := params[0].([]string) + index := params[1].(int) + if index >= len(arr) { + return "", nil + } + return arr[index], nil } -func PathEscape(s string) string { - return url.PathEscape(s) +// func Atof(x string) float64 { +func Atof(params ...any) (any, error) { + x := params[0].(string) + log.Debugf("debug atof %s", x) + ret, err := strconv.ParseFloat(x, 64) + if err != nil { + log.Warningf("Atof : can't convert float '%s' : %v", x, err) + } + return ret, nil } -func PathUnescape(s string) string { +// func Upper(s string) string { +func Upper(params ...any) (any, error) { + s := params[0].(string) + return strings.ToUpper(s), nil +} + +// func Lower(s string) string { +func Lower(params ...any) (any, error) { + s := params[0].(string) + return strings.ToLower(s), nil +} + +// func Distance(lat1 string, long1 string, lat2 string, long2 string) (float64, error) { +func Distance(params ...any) (any, error) { + lat1 := params[0].(string) + long1 := params[1].(string) + lat2 := params[2].(string) + long2 := params[3].(string) + lat1f, err := strconv.ParseFloat(lat1, 64) + if err != nil { + log.Warningf("lat1 is not a float : %v", err) + return 0.0, fmt.Errorf("lat1 is not a float : %v", err) + } + long1f, err := strconv.ParseFloat(long1, 64) + if err != nil { + log.Warningf("long1 is not a float : %v", err) + return 0.0, fmt.Errorf("long1 is not a float : %v", err) + } + lat2f, err := strconv.ParseFloat(lat2, 64) + if err != nil { + log.Warningf("lat2 is not a float : %v", err) + + return 0.0, fmt.Errorf("lat2 is not a float : %v", err) + } + long2f, err := strconv.ParseFloat(long2, 64) + if err != nil { + log.Warningf("long2 is not a float : %v", err) + + return 0.0, fmt.Errorf("long2 is not a float : %v", err) + } + + //either set of coordinates is 0,0, return 0 to avoid FPs + if (lat1f == 0.0 && long1f == 0.0) || (lat2f == 0.0 && long2f == 0.0) { + log.Warningf("one of the coordinates is 0,0, returning 0") + return 0.0, nil + } + + first := haversine.Coord{Lat: lat1f, Lon: long1f} + second := haversine.Coord{Lat: lat2f, Lon: long2f} + + _, km := haversine.Distance(first, second) + return km, nil +} + +// func QueryEscape(s string) string { +func QueryEscape(params ...any) (any, error) { + s := params[0].(string) + return url.QueryEscape(s), nil +} + +// func PathEscape(s string) string { +func PathEscape(params ...any) (any, error) { + s := params[0].(string) + return url.PathEscape(s), nil +} + +// func PathUnescape(s string) string { +func PathUnescape(params ...any) (any, error) { + s := params[0].(string) ret, err := url.PathUnescape(s) if err != nil { log.Debugf("unable to PathUnescape '%s': %+v", s, err) - return s + return s, nil } - return ret + return ret, nil } -func QueryUnescape(s string) string { +// func QueryUnescape(s string) string { +func QueryUnescape(params ...any) (any, error) { + s := params[0].(string) ret, err := url.QueryUnescape(s) if err != nil { log.Debugf("unable to QueryUnescape '%s': %+v", s, err) - return s + return s, nil } - return ret + return ret, nil } -func File(filename string) []string { +// func File(filename string) []string { +func File(params ...any) (any, error) { + filename := params[0].(string) if _, ok := dataFile[filename]; ok { - return dataFile[filename] + return dataFile[filename], nil } log.Errorf("file '%s' (type:string) not found in expr library", filename) log.Errorf("expr library : %s", spew.Sdump(dataFile)) - return []string{} + return []string{}, nil } -func RegexpInFile(data string, filename string) bool { - +// func RegexpInFile(data string, filename string) bool { +func RegexpInFile(params ...any) (any, error) { + data := params[0].(string) + filename := params[1].(string) var hash uint64 hasCache := false @@ -307,7 +291,7 @@ func RegexpInFile(data string, filename string) bool { hasCache = true hash = xxhash.Sum64String(data) if val, err := dataFileRegexCache[filename].Get(hash); err == nil { - return val.(bool) + return val.(bool), nil } } @@ -317,7 +301,7 @@ func RegexpInFile(data string, filename string) bool { if hasCache { dataFileRegexCache[filename].Set(hash, true) } - return true + return true, nil } } } else { @@ -327,149 +311,177 @@ func RegexpInFile(data string, filename string) bool { if hasCache { dataFileRegexCache[filename].Set(hash, false) } - return false + return false, nil } -func IpInRange(ip string, ipRange string) bool { +// func IpInRange(ip string, ipRange string) bool { +func IpInRange(params ...any) (any, error) { var err error var ipParsed net.IP var ipRangeParsed *net.IPNet + ip := params[0].(string) + ipRange := params[1].(string) + ipParsed = net.ParseIP(ip) if ipParsed == nil { log.Debugf("'%s' is not a valid IP", ip) - return false + return false, nil } if _, ipRangeParsed, err = net.ParseCIDR(ipRange); err != nil { log.Debugf("'%s' is not a valid IP Range", ipRange) - return false + return false, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility } if ipRangeParsed.Contains(ipParsed) { - return true + return true, nil } - return false + return false, nil } -func IsIPV6(ip string) bool { +// func IsIPV6(ip string) bool { +func IsIPV6(params ...any) (any, error) { + ip := params[0].(string) ipParsed := net.ParseIP(ip) if ipParsed == nil { log.Debugf("'%s' is not a valid IP", ip) - return false + return false, nil } // If it's a valid IP and can't be converted to IPv4 then it is an IPv6 - return ipParsed.To4() == nil + return ipParsed.To4() == nil, nil } -func IsIPV4(ip string) bool { +// func IsIPV4(ip string) bool { +func IsIPV4(params ...any) (any, error) { + ip := params[0].(string) ipParsed := net.ParseIP(ip) if ipParsed == nil { log.Debugf("'%s' is not a valid IP", ip) - return false + return false, nil } - return ipParsed.To4() != nil + return ipParsed.To4() != nil, nil } -func IsIP(ip string) bool { +// func IsIP(ip string) bool { +func IsIP(params ...any) (any, error) { + ip := params[0].(string) ipParsed := net.ParseIP(ip) if ipParsed == nil { log.Debugf("'%s' is not a valid IP", ip) - return false + return false, nil } - return true + return true, nil } -func IpToRange(ip string, cidr string) string { +// func IpToRange(ip string, cidr string) string { +func IpToRange(params ...any) (any, error) { + ip := params[0].(string) + cidr := params[1].(string) cidr = strings.TrimPrefix(cidr, "/") mask, err := strconv.Atoi(cidr) if err != nil { log.Errorf("bad cidr '%s': %s", cidr, err) - return "" + return "", nil } ipAddr := net.ParseIP(ip) if ipAddr == nil { log.Errorf("can't parse IP address '%s'", ip) - return "" + return "", nil } ipRange := iplib.NewNet(ipAddr, mask) if ipRange.IP() == nil { log.Errorf("can't get cidr '%s' of '%s'", cidr, ip) - return "" + return "", nil } - return ipRange.String() + return ipRange.String(), nil } -func TimeNow() string { - return time.Now().UTC().Format(time.RFC3339) +// func TimeNow() string { +func TimeNow(params ...any) (any, error) { + return time.Now().UTC().Format(time.RFC3339), nil } -func ParseUri(uri string) map[string][]string { +// func ParseUri(uri string) map[string][]string { +func ParseUri(params ...any) (any, error) { + uri := params[0].(string) ret := make(map[string][]string) u, err := url.Parse(uri) if err != nil { log.Errorf("Could not parse URI: %s", err) - return ret + return ret, nil } parsed, err := url.ParseQuery(u.RawQuery) if err != nil { log.Errorf("Could not parse query uri : %s", err) - return ret + return ret, nil } for k, v := range parsed { ret[k] = v } - return ret + return ret, nil } -func KeyExists(key string, dict map[string]interface{}) bool { +// func KeyExists(key string, dict map[string]interface{}) bool { +func KeyExists(params ...any) (any, error) { + key := params[0].(string) + dict := params[1].(map[string]interface{}) _, ok := dict[key] - return ok + return ok, nil } -func GetDecisionsCount(value string) int { +// func GetDecisionsCount(value string) int { +func GetDecisionsCount(params ...any) (any, error) { + value := params[0].(string) if dbClient == nil { log.Error("No database config to call GetDecisionsCount()") - return 0 + return 0, nil } count, err := dbClient.CountDecisionsByValue(value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) - return 0 + return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility } - return count + return count, nil } -func GetDecisionsSinceCount(value string, since string) int { +// func GetDecisionsSinceCount(value string, since string) int { +func GetDecisionsSinceCount(params ...any) (any, error) { + value := params[0].(string) + since := params[1].(string) if dbClient == nil { log.Error("No database config to call GetDecisionsCount()") - return 0 + return 0, nil } sinceDuration, err := time.ParseDuration(since) if err != nil { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) - return 0 + return 0, nil } sinceTime := time.Now().UTC().Add(-sinceDuration) count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) - return 0 + return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility } - return count + return count, nil } -func LookupHost(value string) []string { +// func LookupHost(value string) []string { +func LookupHost(params ...any) (any, error) { + value := params[0].(string) addresses, err := net.LookupHost(value) if err != nil { log.Errorf("Failed to lookup host '%s' : %s", value, err) - return []string{} + return []string{}, nil } - return addresses + return addresses, nil } -func ParseUnixTime(value string) (time.Time, error) { +// func ParseUnixTime(value string) (time.Time, error) { +func ParseUnixTime(params ...any) (any, error) { + value := params[0].(string) //Splitting string here as some unix timestamp may have milliseconds and break ParseInt i, err := strconv.ParseInt(strings.Split(value, ".")[0], 10, 64) if err != nil || i <= 0 { @@ -478,44 +490,75 @@ func ParseUnixTime(value string) (time.Time, error) { return time.Unix(i, 0), nil } -func ParseUnix(value string) string { +// func ParseUnix(value string) string { +func ParseUnix(params ...any) (any, error) { + value := params[0].(string) t, err := ParseUnixTime(value) if err != nil { log.Error(err) - return "" + return "", nil } - return t.Format(time.RFC3339) + return t.(time.Time).Format(time.RFC3339), nil } -func ToString(value interface{}) string { +// func ToString(value interface{}) string { +func ToString(params ...any) (any, error) { + value := params[0] s, ok := value.(string) if !ok { - return "" + return "", nil } - return s + return s, nil } -func Match(pattern, name string) bool { +// func GetFromStash(cacheName string, key string) (string, error) { +func GetFromStash(params ...any) (any, error) { + cacheName := params[0].(string) + key := params[1].(string) + return cache.GetKey(cacheName, key) +} + +// func SetInStash(cacheName string, key string, value string, expiration *time.Duration) any { +func SetInStash(params ...any) (any, error) { + cacheName := params[0].(string) + key := params[1].(string) + value := params[2].(string) + expiration := params[3].(*time.Duration) + return cache.SetKey(cacheName, key, value, expiration), nil +} + +func Sprintf(params ...any) (any, error) { + format := params[0].(string) + return fmt.Sprintf(format, params[1:]...), nil +} + +// func Match(pattern, name string) bool { +func Match(params ...any) (any, error) { var matched bool + + pattern := params[0].(string) + name := params[1].(string) + if pattern == "" { - return name == "" + return name == "", nil } if name == "" { if pattern == "*" || pattern == "" { - return true + return true, nil } - return false + return false, nil } if pattern[0] == '*' { for i := 0; i <= len(name); i++ { - if matched = Match(pattern[1:], name[i:]); matched { - return matched + matched, _ := Match(pattern[1:], name[i:]) + if matched.(bool) { + return matched, nil } } - return matched + return matched, nil } if pattern[0] == '?' || pattern[0] == name[0] { return Match(pattern[1:], name[1:]) } - return matched + return matched, nil } diff --git a/pkg/exprhelpers/jsonextract.go b/pkg/exprhelpers/jsonextract.go index 5b393db9e..12dbb9da8 100644 --- a/pkg/exprhelpers/jsonextract.go +++ b/pkg/exprhelpers/jsonextract.go @@ -11,7 +11,10 @@ import ( log "github.com/sirupsen/logrus" ) -func JsonExtractLib(jsblob string, target ...string) string { +// func JsonExtractLib(jsblob string, target ...string) string { +func JsonExtractLib(params ...any) (any, error) { + jsblob := params[0].(string) + target := params[1].([]string) value, dataType, _, err := jsonparser.Get( jsonparser.StringToBytes(jsblob), target..., @@ -20,45 +23,62 @@ func JsonExtractLib(jsblob string, target ...string) string { if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { log.Debugf("%+v doesn't exist", target) - return "" + return "", nil } log.Errorf("jsonExtractLib : %+v : %s", target, err) - return "" + return "", nil } if dataType == jsonparser.NotExist { log.Debugf("%+v doesn't exist", target) - return "" + return "", nil } strvalue := string(value) - return strvalue + return strvalue, nil } -func JsonExtractUnescape(jsblob string, target ...string) string { - value, err := jsonparser.GetString( - jsonparser.StringToBytes(jsblob), - target..., - ) +// func JsonExtractUnescape(jsblob string, target ...string) string { +func JsonExtractUnescape(params ...any) (any, error) { + var value string + var err error + jsblob := params[0].(string) + switch v := params[1].(type) { + case string: + target := v + value, err = jsonparser.GetString( + jsonparser.StringToBytes(jsblob), + target, + ) + case []string: + target := v + value, err = jsonparser.GetString( + jsonparser.StringToBytes(jsblob), + target..., + ) + } if err != nil { if errors.Is(err, jsonparser.KeyPathNotFoundError) { - log.Debugf("%+v doesn't exist", target) - return "" + log.Debugf("%+v doesn't exist", params[1]) + return "", nil } - log.Errorf("JsonExtractUnescape : %+v : %s", target, err) - return "" + log.Errorf("JsonExtractUnescape : %+v : %s", params[1], err) + return "", nil } - log.Tracef("extract path %+v", target) - return value + log.Tracef("extract path %+v", params[1]) + return value, nil } -func JsonExtract(jsblob string, target string) string { +// func JsonExtract(jsblob string, target string) string { +func JsonExtract(params ...any) (any, error) { + jsblob := params[0].(string) + target := params[1].(string) if !strings.HasPrefix(target, "[") { target = strings.ReplaceAll(target, "[", ".[") } fullpath := strings.Split(target, ".") log.Tracef("extract path %+v", fullpath) - return JsonExtractLib(jsblob, fullpath...) + return JsonExtractLib(jsblob, fullpath) } func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]byte, error) { @@ -91,13 +111,15 @@ func jsonExtractType(jsblob string, target string, t jsonparser.ValueType) ([]by return value, nil } -func JsonExtractSlice(jsblob string, target string) []interface{} { - +// func JsonExtractSlice(jsblob string, target string) []interface{} { +func JsonExtractSlice(params ...any) (any, error) { + jsblob := params[0].(string) + target := params[1].(string) value, err := jsonExtractType(jsblob, target, jsonparser.Array) if err != nil { log.Errorf("JsonExtractSlice : %s", err) - return nil + return []interface{}(nil), nil } s := make([]interface{}, 0) @@ -105,18 +127,20 @@ func JsonExtractSlice(jsblob string, target string) []interface{} { err = json.Unmarshal(value, &s) if err != nil { log.Errorf("JsonExtractSlice: could not convert '%s' to slice: %s", value, err) - return nil + return []interface{}(nil), nil } - return s + return s, nil } -func JsonExtractObject(jsblob string, target string) map[string]interface{} { - +// func JsonExtractObject(jsblob string, target string) map[string]interface{} { +func JsonExtractObject(params ...any) (any, error) { + jsblob := params[0].(string) + target := params[1].(string) value, err := jsonExtractType(jsblob, target, jsonparser.Object) if err != nil { log.Errorf("JsonExtractObject: %s", err) - return nil + return map[string]interface{}(nil), nil } s := make(map[string]interface{}) @@ -124,16 +148,18 @@ func JsonExtractObject(jsblob string, target string) map[string]interface{} { err = json.Unmarshal(value, &s) if err != nil { log.Errorf("JsonExtractObject: could not convert '%s' to map[string]interface{}: %s", value, err) - return nil + return map[string]interface{}(nil), nil } - return s + return s, nil } -func ToJson(obj interface{}) string { +// func ToJson(obj interface{}) string { +func ToJson(params ...any) (any, error) { + obj := params[0] b, err := json.Marshal(obj) if err != nil { log.Errorf("ToJson : %s", err) - return "" + return "", nil } - return string(b) + return string(b), nil } diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 6b85df8a3..594087474 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -4,6 +4,7 @@ import ( "log" "testing" + "github.com/antonmedv/expr" "github.com/stretchr/testify/assert" ) @@ -22,34 +23,43 @@ func TestJsonExtract(t *testing.T) { jsonBlob string targetField string expectResult string + expr string }{ { name: "basic json extract", jsonBlob: `{"test" : "1234"}`, targetField: "test", expectResult: "1234", + expr: "JsonExtract(blob, target)", }, { name: "basic json extract with non existing field", jsonBlob: `{"test" : "1234"}`, targetField: "non_existing_field", expectResult: "", + expr: "JsonExtract(blob, target)", }, { name: "extract subfield", jsonBlob: `{"test" : {"a": "b"}}`, targetField: "test.a", expectResult: "b", + expr: "JsonExtract(blob, target)", }, } for _, test := range tests { - result := JsonExtract(test.jsonBlob, test.targetField) - isOk := assert.Equal(t, test.expectResult, result) - if !isOk { - t.Fatalf("test '%s' failed", test.name) - } - log.Printf("test '%s' : OK", test.name) + t.Run(test.name, func(t *testing.T) { + env := map[string]interface{}{ + "blob": test.jsonBlob, + "target": test.targetField, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + assert.NoError(t, err) + out, err := expr.Run(vm, env) + assert.NoError(t, err) + assert.Equal(t, test.expectResult, out) + }) } } @@ -68,28 +78,36 @@ func TestJsonExtractUnescape(t *testing.T) { jsonBlob string targetField string expectResult string + expr string }{ { name: "basic json extract", jsonBlob: `{"log" : "\"GET /JBNwtQ6i.blt HTTP/1.1\" 200 13 \"-\" \"Craftbot\""}`, targetField: "log", expectResult: "\"GET /JBNwtQ6i.blt HTTP/1.1\" 200 13 \"-\" \"Craftbot\"", + expr: "JsonExtractUnescape(blob, target)", }, { name: "basic json extract with non existing field", jsonBlob: `{"test" : "1234"}`, targetField: "non_existing_field", expectResult: "", + expr: "JsonExtractUnescape(blob, target)", }, } for _, test := range tests { - result := JsonExtractUnescape(test.jsonBlob, test.targetField) - isOk := assert.Equal(t, test.expectResult, result) - if !isOk { - t.Fatalf("test '%s' failed", test.name) - } - log.Printf("test '%s' : OK", test.name) + t.Run(test.name, func(t *testing.T) { + env := map[string]interface{}{ + "blob": test.jsonBlob, + "target": test.targetField, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + assert.NoError(t, err) + out, err := expr.Run(vm, env) + assert.NoError(t, err) + assert.Equal(t, test.expectResult, out) + }) } } @@ -108,38 +126,50 @@ func TestJsonExtractSlice(t *testing.T) { jsonBlob string targetField string expectResult []interface{} + expr string }{ { name: "try to extract a string as a slice", jsonBlob: `{"test" : "1234"}`, targetField: "test", expectResult: nil, + expr: "JsonExtractSlice(blob, target)", }, { name: "basic json slice extract", jsonBlob: `{"test" : ["1234"]}`, targetField: "test", expectResult: []interface{}{"1234"}, + expr: "JsonExtractSlice(blob, target)", }, { name: "extract with complex expression", jsonBlob: `{"test": {"foo": [{"a":"b"}]}}`, targetField: "test.foo", expectResult: []interface{}{map[string]interface{}{"a": "b"}}, + expr: "JsonExtractSlice(blob, target)", }, { name: "extract non-existing key", jsonBlob: `{"test: "11234"}`, targetField: "foo", expectResult: nil, + expr: "JsonExtractSlice(blob, target)", }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - result := JsonExtractSlice(test.jsonBlob, test.targetField) - assert.Equal(t, test.expectResult, result) + env := map[string]interface{}{ + "blob": test.jsonBlob, + "target": test.targetField, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + assert.NoError(t, err) + out, err := expr.Run(vm, env) + assert.NoError(t, err) + assert.Equal(t, test.expectResult, out) }) } } @@ -159,61 +189,79 @@ func TestJsonExtractObject(t *testing.T) { jsonBlob string targetField string expectResult map[string]interface{} + expr string }{ { name: "try to extract a string as an object", jsonBlob: `{"test" : "1234"}`, targetField: "test", expectResult: nil, + expr: "JsonExtractObject(blob, target)", }, { name: "basic json object extract", jsonBlob: `{"test" : {"1234": {"foo": "bar"}}}`, targetField: "test", expectResult: map[string]interface{}{"1234": map[string]interface{}{"foo": "bar"}}, + expr: "JsonExtractObject(blob, target)", }, { name: "extract with complex expression", jsonBlob: `{"test": {"foo": [{"a":"b"}]}}`, targetField: "test.foo[0]", expectResult: map[string]interface{}{"a": "b"}, + expr: "JsonExtractObject(blob, target)", }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - result := JsonExtractObject(test.jsonBlob, test.targetField) - assert.Equal(t, test.expectResult, result) + env := map[string]interface{}{ + "blob": test.jsonBlob, + "target": test.targetField, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + assert.NoError(t, err) + out, err := expr.Run(vm, env) + assert.NoError(t, err) + assert.Equal(t, test.expectResult, out) }) } } func TestToJson(t *testing.T) { + err := Init(nil) + assert.NoError(t, err) tests := []struct { name string obj interface{} expectResult string + expr string }{ { name: "convert int", obj: 42, expectResult: "42", + expr: "ToJsonString(obj)", }, { name: "convert slice", obj: []string{"foo", "bar"}, expectResult: `["foo","bar"]`, + expr: "ToJsonString(obj)", }, { name: "convert map", obj: map[string]string{"foo": "bar"}, expectResult: `{"foo":"bar"}`, + expr: "ToJsonString(obj)", }, { name: "convert struct", obj: struct{ Foo string }{"bar"}, expectResult: `{"Foo":"bar"}`, + expr: "ToJsonString(obj)", }, { name: "convert complex struct", @@ -233,19 +281,26 @@ func TestToJson(t *testing.T) { Bla: []string{"foo", "bar"}, }, expectResult: `{"Foo":"bar","Bar":{"Baz":"baz"},"Bla":["foo","bar"]}`, + expr: "ToJsonString(obj)", }, { name: "convert invalid type", obj: func() {}, expectResult: "", + expr: "ToJsonString(obj)", }, } for _, test := range tests { - test := test t.Run(test.name, func(t *testing.T) { - result := ToJson(test.obj) - assert.Equal(t, test.expectResult, result) + env := map[string]interface{}{ + "obj": test.obj, + } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) + assert.NoError(t, err) + out, err := expr.Run(vm, env) + assert.NoError(t, err) + assert.Equal(t, test.expectResult, out) }) } } diff --git a/pkg/exprhelpers/strings.go b/pkg/exprhelpers/strings.go new file mode 100644 index 000000000..81c6776b6 --- /dev/null +++ b/pkg/exprhelpers/strings.go @@ -0,0 +1,69 @@ +package exprhelpers + +import "strings" + +//Wrappers for stdlib strings function exposed in expr + +func Fields(params ...any) (any, error) { + return strings.Fields(params[0].(string)), nil +} + +func Index(params ...any) (any, error) { + return strings.Index(params[0].(string), params[1].(string)), nil +} + +func IndexAny(params ...any) (any, error) { + return strings.IndexAny(params[0].(string), params[1].(string)), nil +} + +func Join(params ...any) (any, error) { + return strings.Join(params[0].([]string), params[1].(string)), nil +} + +func Split(params ...any) (any, error) { + return strings.Split(params[0].(string), params[1].(string)), nil +} + +func SplitAfter(params ...any) (any, error) { + return strings.SplitAfter(params[0].(string), params[1].(string)), nil +} + +func SplitAfterN(params ...any) (any, error) { + return strings.SplitAfterN(params[0].(string), params[1].(string), params[2].(int)), nil +} + +func SplitN(params ...any) (any, error) { + return strings.SplitN(params[0].(string), params[1].(string), params[2].(int)), nil +} + +func Replace(params ...any) (any, error) { + return strings.Replace(params[0].(string), params[1].(string), params[2].(string), params[3].(int)), nil +} + +func ReplaceAll(params ...any) (any, error) { + return strings.ReplaceAll(params[0].(string), params[1].(string), params[2].(string)), nil +} + +func Trim(params ...any) (any, error) { + return strings.Trim(params[0].(string), params[1].(string)), nil +} + +func TrimLeft(params ...any) (any, error) { + return strings.TrimLeft(params[0].(string), params[1].(string)), nil +} + +func TrimPrefix(params ...any) (any, error) { + return strings.TrimPrefix(params[0].(string), params[1].(string)), nil +} + +func TrimRight(params ...any) (any, error) { + return strings.TrimRight(params[0].(string), params[1].(string)), nil +} + +func TrimSpace(params ...any) (any, error) { + return strings.TrimSpace(params[0].(string)), nil +} + +func TrimSuffix(params ...any) (any, error) { + return strings.TrimSuffix(params[0].(string), params[1].(string)), nil +} diff --git a/pkg/exprhelpers/visitor.go b/pkg/exprhelpers/visitor.go index 5dfcc6085..b5bc97e7c 100644 --- a/pkg/exprhelpers/visitor.go +++ b/pkg/exprhelpers/visitor.go @@ -81,7 +81,7 @@ func (v *visitor) Visit(node *ast.Node) { /* Build reconstruct all the variables used in a filter (to display their content later). */ -func (v *visitor) Build(filter string, exprEnv expr.Option) (*ExprDebugger, error) { +func (v *visitor) Build(filter string, exprEnv ...expr.Option) (*ExprDebugger, error) { var expressions []*expression ret := &ExprDebugger{ filter: filter, @@ -105,7 +105,7 @@ func (v *visitor) Build(filter string, exprEnv expr.Option) (*ExprDebugger, erro } toBuild := strings.Join(variable, ".") v.logger.Debugf("compiling expression '%s'", toBuild) - debugFilter, err := expr.Compile(toBuild, exprEnv) + debugFilter, err := expr.Compile(toBuild, exprEnv...) if err != nil { return ret, fmt.Errorf("compilation of variable '%s' failed: %v", toBuild, err) } @@ -153,9 +153,9 @@ func (e *ExprDebugger) Run(logger *logrus.Entry, filterResult bool, exprEnv map[ } // NewDebugger is the exported function that build the debuggers expressions -func NewDebugger(filter string, exprEnv expr.Option) (*ExprDebugger, error) { +func NewDebugger(filter string, exprEnv ...expr.Option) (*ExprDebugger, error) { logger := log.WithField("component", "expr-debugger") visitor := &visitor{logger: logger} - exprDebugger, err := visitor.Build(filter, exprEnv) + exprDebugger, err := visitor.Build(filter, exprEnv...) return exprDebugger, err } diff --git a/pkg/exprhelpers/xml.go b/pkg/exprhelpers/xml.go index 1d0d4073f..75758e183 100644 --- a/pkg/exprhelpers/xml.go +++ b/pkg/exprhelpers/xml.go @@ -7,13 +7,16 @@ import ( var pathCache = make(map[string]etree.Path) -func XMLGetAttributeValue(xmlString string, path string, attributeName string) string { - +// func XMLGetAttributeValue(xmlString string, path string, attributeName string) string { +func XMLGetAttributeValue(params ...any) (any, error) { + xmlString := params[0].(string) + path := params[1].(string) + attributeName := params[2].(string) if _, ok := pathCache[path]; !ok { compiledPath, err := etree.CompilePath(path) if err != nil { log.Errorf("Could not compile path %s: %s", path, err) - return "" + return "", nil } pathCache[path] = compiledPath } @@ -23,27 +26,30 @@ func XMLGetAttributeValue(xmlString string, path string, attributeName string) s err := doc.ReadFromString(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) - return "" + return "", nil } elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) - return "" + return "", nil } attr := elem.SelectAttr(attributeName) if attr == nil { log.Debugf("Could not find attribute %s", attributeName) - return "" + return "", nil } - return attr.Value + return attr.Value, nil } -func XMLGetNodeValue(xmlString string, path string) string { +// func XMLGetNodeValue(xmlString string, path string) string { +func XMLGetNodeValue(params ...any) (any, error) { + xmlString := params[0].(string) + path := params[1].(string) if _, ok := pathCache[path]; !ok { compiledPath, err := etree.CompilePath(path) if err != nil { log.Errorf("Could not compile path %s: %s", path, err) - return "" + return "", nil } pathCache[path] = compiledPath } @@ -53,12 +59,12 @@ func XMLGetNodeValue(xmlString string, path string) string { err := doc.ReadFromString(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) - return "" + return "", nil } elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) - return "" + return "", nil } - return elem.Text() + return elem.Text(), nil } diff --git a/pkg/exprhelpers/xml_test.go b/pkg/exprhelpers/xml_test.go index 695ac0446..516387f76 100644 --- a/pkg/exprhelpers/xml_test.go +++ b/pkg/exprhelpers/xml_test.go @@ -57,7 +57,7 @@ func TestXMLGetAttributeValue(t *testing.T) { } for _, test := range tests { - result := XMLGetAttributeValue(test.xmlString, test.path, test.attribute) + result, _ := XMLGetAttributeValue(test.xmlString, test.path, test.attribute) isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) @@ -104,7 +104,7 @@ func TestXMLGetNodeValue(t *testing.T) { } for _, test := range tests { - result := XMLGetNodeValue(test.xmlString, test.path) + result, _ := XMLGetNodeValue(test.xmlString, test.path) isOk := assert.Equal(t, test.expectResult, result) if !isOk { t.Fatalf("test '%s' failed", test.name) diff --git a/pkg/hubtest/parser_assert.go b/pkg/hubtest/parser_assert.go index b5d488fbd..e13cdb4a2 100644 --- a/pkg/hubtest/parser_assert.go +++ b/pkg/hubtest/parser_assert.go @@ -153,14 +153,14 @@ func (p *ParserAssert) RunExpression(expression string) (interface{}, error) { env := map[string]interface{}{"results": *p.TestData} - if runtimeFilter, err = expr.Compile(expression, expr.Env(exprhelpers.GetExprEnv(env))); err != nil { + if runtimeFilter, err = expr.Compile(expression, exprhelpers.GetExprOptions(env)...); err != nil { return output, err } //dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) - output, err = expr.Run(runtimeFilter, exprhelpers.GetExprEnv(map[string]interface{}{"results": *p.TestData})) + output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *p.TestData}) if err != nil { log.Warningf("running : %s", expression) log.Warningf("runtime error : %s", err) diff --git a/pkg/hubtest/scenario_assert.go b/pkg/hubtest/scenario_assert.go index 4b8d8992f..d9ec4dddc 100644 --- a/pkg/hubtest/scenario_assert.go +++ b/pkg/hubtest/scenario_assert.go @@ -148,17 +148,17 @@ func (s *ScenarioAssert) RunExpression(expression string) (interface{}, error) { env := map[string]interface{}{"results": *s.TestData} - if runtimeFilter, err = expr.Compile(expression, expr.Env(exprhelpers.GetExprEnv(env))); err != nil { + if runtimeFilter, err = expr.Compile(expression, exprhelpers.GetExprOptions(env)...); err != nil { return output, err } - // if debugFilter, err = exprhelpers.NewDebugger(assert, expr.Env(exprhelpers.GetExprEnv(env))); err != nil { + // if debugFilter, err = exprhelpers.NewDebugger(assert, expr.Env(env)); err != nil { // log.Warningf("Failed building debugher for %s : %s", assert, err) // } //dump opcode in trace level log.Tracef("%s", runtimeFilter.Disassemble()) - output, err = expr.Run(runtimeFilter, exprhelpers.GetExprEnv(map[string]interface{}{"results": *s.TestData})) + output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *s.TestData}) if err != nil { log.Warningf("running : %s", expression) log.Warningf("runtime error : %s", err) diff --git a/pkg/leakybucket/conditional.go b/pkg/leakybucket/conditional.go index f55df04ae..3d2f9fc64 100644 --- a/pkg/leakybucket/conditional.go +++ b/pkg/leakybucket/conditional.go @@ -34,8 +34,7 @@ func (c *ConditionalOverflow) OnBucketInit(g *BucketFactory) error { } else { conditionalExprCacheLock.Unlock() //release the lock during compile - compiledExpr, err = expr.Compile(g.ConditionalOverflow, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{ - "queue": &Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}}))) + compiledExpr, err = expr.Compile(g.ConditionalOverflow, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &Queue{}, "leaky": &Leaky{}, "evt": &types.Event{}})...) if err != nil { return fmt.Errorf("conditional compile error : %w", err) } @@ -52,7 +51,7 @@ func (c *ConditionalOverflow) AfterBucketPour(b *BucketFactory) func(types.Event var condition, ok bool if c.ConditionalFilterRuntime != nil { l.logger.Debugf("Running condition expression : %s", c.ConditionalFilter) - ret, err := expr.Run(c.ConditionalFilterRuntime, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l})) + ret, err := expr.Run(c.ConditionalFilterRuntime, map[string]interface{}{"evt": &msg, "queue": l.Queue, "leaky": l}) if err != nil { l.logger.Errorf("unable to run conditional filter : %s", err) return &msg diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go index 520a6a7be..3789c749d 100644 --- a/pkg/leakybucket/manager_load.go +++ b/pkg/leakybucket/manager_load.go @@ -134,7 +134,7 @@ func ValidateFactory(bucketFactory *BucketFactory) error { err error ) if bucketFactory.ScopeType.Filter != "" { - if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))); err != nil { + if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { return fmt.Errorf("Error compiling the scope filter: %s", err) } bucketFactory.ScopeType.RunTimeFilter = runTimeFilter @@ -147,7 +147,7 @@ func ValidateFactory(bucketFactory *BucketFactory) error { err error ) if bucketFactory.ScopeType.Filter != "" { - if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))); err != nil { + if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { return fmt.Errorf("Error compiling the scope filter: %s", err) } bucketFactory.ScopeType.RunTimeFilter = runTimeFilter @@ -286,19 +286,19 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { bucketFactory.logger.Warning("Bucket without filter, abort.") return fmt.Errorf("bucket without filter directive") } - bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("invalid filter '%s' in %s : %v", bucketFactory.Filter, bucketFactory.Filename, err) } if bucketFactory.Debug { - bucketFactory.ExprDebugger, err = exprhelpers.NewDebugger(bucketFactory.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + bucketFactory.ExprDebugger, err = exprhelpers.NewDebugger(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { log.Errorf("unable to build debug filter for '%s' : %s", bucketFactory.Filter, err) } } if bucketFactory.GroupBy != "" { - bucketFactory.RunTimeGroupBy, err = expr.Compile(bucketFactory.GroupBy, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + bucketFactory.RunTimeGroupBy, err = expr.Compile(bucketFactory.GroupBy, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("invalid groupby '%s' in %s : %v", bucketFactory.GroupBy, bucketFactory.Filename, err) } diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index 00888ecd2..56998d745 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -13,7 +13,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/antonmedv/expr" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/prometheus/client_golang/prometheus" ) @@ -297,8 +296,6 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc BucketPourCache["OK"] = append(BucketPourCache["OK"], evt.(types.Event)) } - cachedExprEnv := exprhelpers.GetExprEnv(map[string]interface{}{"evt": &parsed}) - //find the relevant holders (scenarios) for idx := 0; idx < len(holders); idx++ { //for idx, holder := range holders { @@ -306,7 +303,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc //evaluate bucket's condition if holders[idx].RunTimeFilter != nil { holders[idx].logger.Tracef("event against holder %d/%d", idx, len(holders)) - output, err := expr.Run(holders[idx].RunTimeFilter, cachedExprEnv) + output, err := expr.Run(holders[idx].RunTimeFilter, map[string]interface{}{"evt": &parsed}) if err != nil { holders[idx].logger.Errorf("failed parsing : %v", err) return false, fmt.Errorf("leaky failed : %s", err) @@ -318,7 +315,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc } if holders[idx].Debug { - holders[idx].ExprDebugger.Run(holders[idx].logger, condition, cachedExprEnv) + holders[idx].ExprDebugger.Run(holders[idx].logger, condition, map[string]interface{}{"evt": &parsed}) } if !condition { holders[idx].logger.Debugf("Event leaving node : ko (filter mismatch)") @@ -329,7 +326,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc //groupby determines the partition key for the specific bucket var groupby string if holders[idx].RunTimeGroupBy != nil { - tmpGroupBy, err := expr.Run(holders[idx].RunTimeGroupBy, cachedExprEnv) + tmpGroupBy, err := expr.Run(holders[idx].RunTimeGroupBy, map[string]interface{}{"evt": &parsed}) if err != nil { holders[idx].logger.Errorf("failed groupby : %v", err) return false, errors.New("leaky failed :/") diff --git a/pkg/leakybucket/overflow_filter.go b/pkg/leakybucket/overflow_filter.go index 7be6720b5..c716c22d3 100644 --- a/pkg/leakybucket/overflow_filter.go +++ b/pkg/leakybucket/overflow_filter.go @@ -27,8 +27,8 @@ func NewOverflowFilter(g *BucketFactory) (*OverflowFilter, error) { u := OverflowFilter{} u.Filter = g.OverflowFilter - u.FilterRuntime, err = expr.Compile(u.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{ - "queue": &Queue{}, "signal": &types.RuntimeAlert{}, "leaky": &Leaky{}}))) + + u.FilterRuntime, err = expr.Compile(u.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"queue": &Queue{}, "signal": &types.RuntimeAlert{}, "leaky": &Leaky{}})...) if err != nil { g.logger.Errorf("Unable to compile filter : %v", err) return nil, fmt.Errorf("unable to compile filter : %v", err) @@ -38,8 +38,8 @@ func NewOverflowFilter(g *BucketFactory) (*OverflowFilter, error) { func (u *OverflowFilter) OnBucketOverflow(Bucket *BucketFactory) func(*Leaky, types.RuntimeAlert, *Queue) (types.RuntimeAlert, *Queue) { return func(l *Leaky, s types.RuntimeAlert, q *Queue) (types.RuntimeAlert, *Queue) { - el, err := expr.Run(u.FilterRuntime, exprhelpers.GetExprEnv(map[string]interface{}{ - "queue": q, "signal": s, "leaky": l})) + el, err := expr.Run(u.FilterRuntime, map[string]interface{}{ + "queue": q, "signal": s, "leaky": l}) if err != nil { l.logger.Errorf("Failed running overflow filter: %s", err) return s, q diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index c1f051523..d6131cd26 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/antonmedv/expr" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) // SourceFromEvent extracts and formats a valid models.Source object from an Event @@ -52,7 +51,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e *src.Value = v.Range } if leaky.scopeType.RunTimeFilter != nil { - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &evt})) + retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) if err != nil { return srcs, errors.Wrapf(err, "while running scope filter") } @@ -127,7 +126,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e } else if leaky.scopeType.Scope == types.Range { src.Value = &src.Range if leaky.scopeType.RunTimeFilter != nil { - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &evt})) + retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) if err != nil { return srcs, errors.Wrapf(err, "while running scope filter") } @@ -144,7 +143,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if leaky.scopeType.RunTimeFilter == nil { return srcs, fmt.Errorf("empty scope information") } - retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &evt})) + retValue, err := expr.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}) if err != nil { return srcs, errors.Wrapf(err, "while running scope filter") } diff --git a/pkg/leakybucket/reset_filter.go b/pkg/leakybucket/reset_filter.go index 0c925afdb..9b64681ab 100644 --- a/pkg/leakybucket/reset_filter.go +++ b/pkg/leakybucket/reset_filter.go @@ -34,14 +34,14 @@ func (u *CancelOnFilter) OnBucketPour(bucketFactory *BucketFactory) func(types.E var condition, ok bool if u.CancelOnFilter != nil { leaky.logger.Tracef("running cancel_on filter") - output, err := expr.Run(u.CancelOnFilter, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &msg})) + output, err := expr.Run(u.CancelOnFilter, map[string]interface{}{"evt": &msg}) if err != nil { leaky.logger.Warningf("cancel_on error : %s", err) return &msg } //only run debugger expression if condition is false if u.CancelOnFilterDebug != nil { - u.CancelOnFilterDebug.Run(leaky.logger, condition, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &msg})) + u.CancelOnFilterDebug.Run(leaky.logger, condition, map[string]interface{}{"evt": &msg}) } if condition, ok = output.(bool); !ok { leaky.logger.Warningf("cancel_on, unexpected non-bool return : %T", output) @@ -93,14 +93,16 @@ func (u *CancelOnFilter) OnBucketInit(bucketFactory *BucketFactory) error { } else { cancelExprCacheLock.Unlock() //release the lock during compile - compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + + compiledExpr.CancelOnFilter, err = expr.Compile(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { bucketFactory.logger.Errorf("reset_filter compile error : %s", err) return err } u.CancelOnFilter = compiledExpr.CancelOnFilter if bucketFactory.Debug { - compiledExpr.CancelOnFilterDebug, err = exprhelpers.NewDebugger(bucketFactory.CancelOnFilter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + compiledExpr.CancelOnFilterDebug, err = exprhelpers.NewDebugger(bucketFactory.CancelOnFilter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})..., + ) if err != nil { bucketFactory.logger.Errorf("reset_filter debug error : %s", err) return err diff --git a/pkg/leakybucket/uniq.go b/pkg/leakybucket/uniq.go index 2c9ddb11e..cb8bf63fe 100644 --- a/pkg/leakybucket/uniq.go +++ b/pkg/leakybucket/uniq.go @@ -74,7 +74,7 @@ func (u *Uniq) OnBucketInit(bucketFactory *BucketFactory) error { } else { uniqExprCacheLock.Unlock() //release the lock during compile - compiledExpr, err = expr.Compile(bucketFactory.Distinct, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + compiledExpr, err = expr.Compile(bucketFactory.Distinct, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) u.DistinctCompiled = compiledExpr uniqExprCacheLock.Lock() uniqExprCache[bucketFactory.Distinct] = *compiledExpr @@ -86,7 +86,7 @@ func (u *Uniq) OnBucketInit(bucketFactory *BucketFactory) error { // getElement computes a string from an event and a filter func getElement(msg types.Event, cFilter *vm.Program) (string, error) { - el, err := expr.Run(cFilter, exprhelpers.GetExprEnv(map[string]interface{}{"evt": &msg})) + el, err := expr.Run(cFilter, map[string]interface{}{"evt": &msg}) if err != nil { return "", err } diff --git a/pkg/parser/enrich_date.go b/pkg/parser/enrich_date.go index c8e49e3d7..a1dd994be 100644 --- a/pkg/parser/enrich_date.go +++ b/pkg/parser/enrich_date.go @@ -84,10 +84,10 @@ func ParseDate(in string, p *types.Event, x interface{}, plog *log.Entry) (map[s } timeobj, err := expr.ParseUnixTime(in) if err == nil { - ret["MarshaledTime"] = timeobj.Format(time.RFC3339) + ret["MarshaledTime"] = timeobj.(time.Time).Format(time.RFC3339) //In time machine, we take the time parsed from the event. In live mode, we keep the timestamp collected at acquisition if p.ExpectMode == types.TIMEMACHINE { - p.Time = timeobj + p.Time = timeobj.(time.Time) } return ret, nil } diff --git a/pkg/parser/node.go b/pkg/parser/node.go index cdd6f7119..5a3b972d6 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -472,13 +472,13 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { //compile filter if present if n.Filter != "" { - n.RunTimeFilter, err = expr.Compile(n.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + n.RunTimeFilter, err = expr.Compile(n.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return fmt.Errorf("compilation of '%s' failed: %v", n.Filter, err) } if n.Debug { - n.ExprDebugger, err = exprhelpers.NewDebugger(n.Filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + n.ExprDebugger, err = exprhelpers.NewDebugger(n.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { log.Errorf("unable to build debug filter for '%s' : %s", n.Filter, err) } @@ -530,7 +530,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /*if grok source is an expression*/ if n.Grok.ExpValue != "" { n.Grok.RunTimeValue, err = expr.Compile(n.Grok.ExpValue, - expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return errors.Wrap(err, "while compiling grok's expression") } @@ -542,7 +542,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { for idx := range n.Grok.Statics { if n.Grok.Statics[idx].ExpValue != "" { n.Grok.Statics[idx].RunTimeValue, err = expr.Compile(n.Grok.Statics[idx].ExpValue, - expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return err } @@ -554,13 +554,13 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* load data capture (stash) */ for i, stash := range n.Stash { n.Stash[i].ValueExpression, err = expr.Compile(stash.Value, - expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return errors.Wrap(err, "while compiling stash value expression") } n.Stash[i].KeyExpression, err = expr.Compile(stash.Key, - expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return errors.Wrap(err, "while compiling stash key expression") } @@ -607,7 +607,7 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { /* load statics if present */ for idx := range n.Statics { if n.Statics[idx].ExpValue != "" { - n.Statics[idx].RunTimeValue, err = expr.Compile(n.Statics[idx].ExpValue, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + n.Statics[idx].RunTimeValue, err = expr.Compile(n.Statics[idx].ExpValue, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { n.Logger.Errorf("Statics Compilation failed %v.", err) return err @@ -633,11 +633,11 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } for _, filter := range n.Whitelist.Exprs { expression := &ExprWhitelist{} - expression.Filter, err = expr.Compile(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { n.Logger.Fatalf("Unable to compile whitelist expression '%s' : %v.", filter, err) } - expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, expr.Env(exprhelpers.GetExprEnv(map[string]interface{}{"evt": &types.Event{}}))) + expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { log.Errorf("unable to build debug filter for '%s' : %s", filter, err) } diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index 53d24d00d..4541eafd9 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -13,7 +13,6 @@ import ( "sync" "time" - "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" "strconv" @@ -118,14 +117,12 @@ func (n *Node) ProcessStatics(statics []types.ExtraField, event *types.Event) er var value string clog := n.Logger - cachedExprEnv := exprhelpers.GetExprEnv(map[string]interface{}{"evt": event}) - for _, static := range statics { value = "" if static.Value != "" { value = static.Value } else if static.RunTimeValue != nil { - output, err := expr.Run(static.RunTimeValue, cachedExprEnv) + output, err := expr.Run(static.RunTimeValue, map[string]interface{}{"evt": event}) if err != nil { clog.Warningf("failed to run RunTimeValue : %v", err) continue @@ -272,8 +269,6 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) log.Tracef("INPUT '%s'", event.Line.Raw) } - cachedExprEnv := exprhelpers.GetExprEnv(map[string]interface{}{"evt": &event}) - if ParseDump { if StageParseCache == nil { StageParseMutex.Lock() @@ -321,7 +316,7 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) if ctx.Profiling { node.Profiling = true } - ret, err := node.process(&event, ctx, cachedExprEnv) + ret, err := node.process(&event, ctx, map[string]interface{}{"evt": &event}) if err != nil { clog.Errorf("Error while processing node : %v", err) return event, err