diff --git a/pkg/parser/node.go b/pkg/parser/node.go index 159a11813..9f848535f 100644 --- a/pkg/parser/node.go +++ b/pkg/parser/node.go @@ -3,7 +3,6 @@ package parser import ( "errors" "fmt" - "net" "strings" "time" @@ -172,75 +171,24 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri if n.Name != "" { NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc() } - isWhitelisted := false - hasWhitelist := false - var srcs []net.IP - /*overflow and log don't hold the source ip in the same field, should be changed */ - /* perform whitelist checks for ips, cidr accordingly */ - /* TODO move whitelist elsewhere */ - if p.Type == types.LOG { - if _, ok := p.Meta["source_ip"]; ok { - srcs = append(srcs, net.ParseIP(p.Meta["source_ip"])) - } - } else if p.Type == types.OVFLW { - for k := range p.Overflow.Sources { - srcs = append(srcs, net.ParseIP(k)) - } + exprErr := error(nil) + isWhitelisted := n.CheckIPsWL(p.ParseIPSources()) + if !isWhitelisted { + isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv) } - for _, src := range srcs { - if isWhitelisted { - break - } - for _, v := range n.Whitelist.B_Ips { - if v.Equal(src) { - clog.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason) - isWhitelisted = true - } else { - clog.Tracef("whitelist: %s is not eq [%s]", src, v) - } - hasWhitelist = true - } - for _, v := range n.Whitelist.B_Cidrs { - if v.Contains(src) { - clog.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason) - isWhitelisted = true - } else { - clog.Tracef("whitelist: %s not in [%s]", src, v) - } - hasWhitelist = true - } + if exprErr != nil { + // Previous code returned nil if there was an error, so we keep this behavior + return false, nil //nolint:nilerr } - /* run whitelist expression tests anyway */ - for eidx, e := range n.Whitelist.B_Exprs { - output, err := expr.Run(e.Filter, cachedExprEnv) - if err != nil { - clog.Warningf("failed to run whitelist expr : %v", err) - clog.Debug("Event leaving node : ko") - return false, nil - } - switch out := output.(type) { - case bool: - if n.Debug { - e.ExprDebugger.Run(clog, out, cachedExprEnv) - } - if out { - clog.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason) - isWhitelisted = true - } - hasWhitelist = true - default: - log.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx]) - } - } if isWhitelisted && !p.Whitelisted { p.Whitelisted = true p.WhitelistReason = n.Whitelist.Reason /*huglily wipe the ban order if the event is whitelisted and it's an overflow */ if p.Type == types.OVFLW { /*don't do this at home kids */ ips := []string{} - for _, src := range srcs { - ips = append(ips, src.String()) + for k := range p.Overflow.Sources { + ips = append(ips, k) } clog.Infof("Ban for %s whitelisted, reason [%s]", strings.Join(ips, ","), n.Whitelist.Reason) p.Overflow.Whitelisted = true @@ -395,9 +343,10 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri } /* - This is to apply statics when the node *has* whitelists that successfully matched the node. + This is to apply statics when the node either was whitelisted, or is not a whitelist (it has no expr/ips wl) + It is overconvoluted and should be simplified */ - if len(n.Statics) > 0 && (isWhitelisted || !hasWhitelist) { + if len(n.Statics) > 0 && (isWhitelisted || !n.ContainsWLs()) { clog.Debugf("+ Processing %d statics", len(n.Statics)) // if all else is good in whitelist, process node's statics err := n.ProcessStatics(n.Statics, p) @@ -610,36 +559,11 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error { } /* compile whitelists if present */ - for _, v := range n.Whitelist.Ips { - n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v)) - n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v)) - valid = true - } - - for _, v := range n.Whitelist.Cidrs { - _, tnet, err := net.ParseCIDR(v) - if err != nil { - n.Logger.Fatalf("Unable to parse cidr whitelist '%s' : %v.", v, err) - } - n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet) - n.Logger.Debugf("adding cidr %s to whitelists", tnet) - valid = true - } - - for _, filter := range n.Whitelist.Exprs { - expression := &ExprWhitelist{} - expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - n.Logger.Fatalf("Unable to compile whitelist expression '%s' : %v.", filter, err) - } - expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) - if err != nil { - log.Errorf("unable to build debug filter for '%s' : %s", filter, err) - } - n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression) - n.Logger.Debugf("adding expression %s to whitelists", filter) - valid = true + whitelistValid, err := n.CompileWLs() + if err != nil { + return err } + valid = valid || whitelistValid if !valid { /* node is empty, error force return */ diff --git a/pkg/parser/whitelist.go b/pkg/parser/whitelist.go index e2f179fb3..8c18e70c3 100644 --- a/pkg/parser/whitelist.go +++ b/pkg/parser/whitelist.go @@ -1,11 +1,14 @@ package parser import ( + "fmt" "net" + "github.com/antonmedv/expr" "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" + "github.com/crowdsecurity/crowdsec/pkg/types" ) type Whitelist struct { @@ -22,3 +25,111 @@ type ExprWhitelist struct { Filter *vm.Program ExprDebugger *exprhelpers.ExprDebugger // used to debug expression by printing the content of each variable of the expression } + +func (n *Node) ContainsWLs() bool { + return n.ContainsIPLists() || n.ContainsExprLists() +} + +func (n *Node) ContainsExprLists() bool { + return len(n.Whitelist.B_Exprs) > 0 +} + +func (n *Node) ContainsIPLists() bool { + return len(n.Whitelist.B_Ips) > 0 || len(n.Whitelist.B_Cidrs) > 0 +} + +func (n *Node) CheckIPsWL(srcs []net.IP) bool { + isWhitelisted := false + if !n.ContainsIPLists() { + return isWhitelisted + } + for _, src := range srcs { + if isWhitelisted { + break + } + for _, v := range n.Whitelist.B_Ips { + if v.Equal(src) { + n.Logger.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason) + isWhitelisted = true + break + } + n.Logger.Tracef("whitelist: %s is not eq [%s]", src, v) + } + for _, v := range n.Whitelist.B_Cidrs { + if v.Contains(src) { + n.Logger.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason) + isWhitelisted = true + break + } + n.Logger.Tracef("whitelist: %s not in [%s]", src, v) + } + } + return isWhitelisted +} + +func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) { + isWhitelisted := false + + if !n.ContainsExprLists() { + return false, nil + } + /* run whitelist expression tests anyway */ + for eidx, e := range n.Whitelist.B_Exprs { + //if we already know the event is whitelisted, skip the rest of the expressions + if isWhitelisted { + break + } + output, err := expr.Run(e.Filter, cachedExprEnv) + if err != nil { + n.Logger.Warningf("failed to run whitelist expr : %v", err) + n.Logger.Debug("Event leaving node : ko") + return isWhitelisted, err + } + switch out := output.(type) { + case bool: + if n.Debug { + e.ExprDebugger.Run(n.Logger, out, cachedExprEnv) + } + if out { + n.Logger.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason) + isWhitelisted = true + } + default: + n.Logger.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx]) + } + } + return isWhitelisted, nil +} + +func (n *Node) CompileWLs() (bool, error) { + for _, v := range n.Whitelist.Ips { + n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v)) + n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v)) + } + + for _, v := range n.Whitelist.Cidrs { + _, tnet, err := net.ParseCIDR(v) + if err != nil { + return false, fmt.Errorf("unable to parse cidr whitelist '%s' : %v", v, err) + } + n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet) + n.Logger.Debugf("adding cidr %s to whitelists", tnet) + } + + for _, filter := range n.Whitelist.Exprs { + var err error + expression := &ExprWhitelist{} + expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + if err != nil { + return false, fmt.Errorf("unable to compile whitelist expression '%s' : %v", filter, err) + } + expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + if err != nil { + n.Logger.Errorf("unable to build debug filter for '%s' : %s", filter, err) + } + + n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression) + n.Logger.Debugf("adding expression %s to whitelists", filter) + } + return n.ContainsWLs(), nil +} diff --git a/pkg/parser/whitelist_test.go b/pkg/parser/whitelist_test.go new file mode 100644 index 000000000..8796aaeda --- /dev/null +++ b/pkg/parser/whitelist_test.go @@ -0,0 +1,300 @@ +package parser + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func TestWhitelistCompile(t *testing.T) { + node := &Node{ + Logger: log.NewEntry(log.New()), + } + tests := []struct { + name string + whitelist Whitelist + expectedErr string + }{ + { + name: "Valid CIDR whitelist", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/24", + }, + }, + }, + { + name: "Invalid CIDR whitelist", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/1000", + }, + }, + expectedErr: "invalid CIDR address", + }, + { + name: "Valid EXPR whitelist", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "1==1", + }, + }, + }, + { + name: "Invalid EXPR whitelist", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.THISPROPERTYSHOULDERROR == true", + }, + }, + expectedErr: "types.Event has no field", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + node.Whitelist = tt.whitelist + _, err := node.CompileWLs() + cstest.RequireErrorContains(t, err, tt.expectedErr) + }) + } +} + +func TestWhitelistCheck(t *testing.T) { + node := &Node{ + Logger: log.NewEntry(log.New()), + } + tests := []struct { + name string + whitelist Whitelist + event *types.Event + expected bool + }{ + { + name: "IP Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "127.0.0.1", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "IP Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "127.0.0.1", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "CIDR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/32", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "CIDR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "127.0.0.1/32", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "EXPR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Meta.source_ip == '127.0.0.1'", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: true, + }, + { + name: "EXPR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Meta.source_ip == '127.0.0.1'", + }, + }, + event: &types.Event{ + Meta: map[string]string{ + "source_ip": "127.0.0.2", + }, + }, + }, + { + name: "Postoverflow IP Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "192.168.1.1", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow IP Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Ips: []string{ + "192.168.1.2", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + }, + { + name: "Postoverflow CIDR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "192.168.1.1/32", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow CIDR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Cidrs: []string{ + "192.168.1.2/32", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + Sources: map[string]models.Source{ + "192.168.1.1": {}, + }, + }, + }, + }, + { + name: "Postoverflow EXPR Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Overflow.APIAlerts[0].Source.Cn == 'test'", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + APIAlerts: []models.Alert{ + { + Source: &models.Source{ + Cn: "test", + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "Postoverflow EXPR Not Whitelisted", + whitelist: Whitelist{ + Reason: "test", + Exprs: []string{ + "evt.Overflow.APIAlerts[0].Source.Cn == 'test2'", + }, + }, + event: &types.Event{ + Type: types.OVFLW, + Overflow: types.RuntimeAlert{ + APIAlerts: []models.Alert{ + { + Source: &models.Source{ + Cn: "test", + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + var err error + node.Whitelist = tt.whitelist + node.CompileWLs() + isWhitelisted := node.CheckIPsWL(tt.event.ParseIPSources()) + if !isWhitelisted { + isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event}) + } + require.NoError(t, err) + require.Equal(t, tt.expected, isWhitelisted) + }) + } +} diff --git a/pkg/types/event.go b/pkg/types/event.go index fc8d966ab..622d1d8bc 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -1,6 +1,7 @@ package types import ( + "net" "time" log "github.com/sirupsen/logrus" @@ -73,6 +74,21 @@ func (e *Event) GetMeta(key string) string { return "" } +func (e *Event) ParseIPSources() []net.IP { + var srcs []net.IP + switch e.Type { + case LOG: + if _, ok := e.Meta["source_ip"]; ok { + srcs = append(srcs, net.ParseIP(e.Meta["source_ip"])) + } + case OVFLW: + for k := range e.Overflow.Sources { + srcs = append(srcs, net.ParseIP(k)) + } + } + return srcs +} + // Move in leakybuckets const ( Undefined = "" diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go new file mode 100644 index 000000000..c3261c647 --- /dev/null +++ b/pkg/types/event_test.go @@ -0,0 +1,79 @@ +package types + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +func TestParseIPSources(t *testing.T) { + tests := []struct { + name string + evt Event + expected []net.IP + }{ + { + name: "ParseIPSources: Valid Log Sources", + evt: Event{ + Type: LOG, + Meta: map[string]string{ + "source_ip": "127.0.0.1", + }, + }, + expected: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + { + name: "ParseIPSources: Valid Overflow Sources", + evt: Event{ + Type: OVFLW, + Overflow: RuntimeAlert{ + Sources: map[string]models.Source{ + "127.0.0.1": {}, + }, + }, + }, + expected: []net.IP{ + net.ParseIP("127.0.0.1"), + }, + }, + { + name: "ParseIPSources: Invalid Log Sources", + evt: Event{ + Type: LOG, + Meta: map[string]string{ + "source_ip": "IAMNOTANIP", + }, + }, + expected: []net.IP{ + nil, + }, + }, + { + name: "ParseIPSources: Invalid Overflow Sources", + evt: Event{ + Type: OVFLW, + Overflow: RuntimeAlert{ + Sources: map[string]models.Source{ + "IAMNOTANIP": {}, + }, + }, + }, + expected: []net.IP{ + nil, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ips := tt.evt.ParseIPSources() + assert.Equal(t, ips, tt.expected) + }) + } +}