Преглед на файлове

Runtime whitelist parsing improvement (#2422)

* Improve whitelist parsing

* Split whitelist check into a function tied to whitelist, also since we check node debug we can make a pointer to node containing whitelist

* No point passing clog as an argument since it is just a pointer to node we already know about

* We should break instead of returning false, false as it may have been whitelisted by ips/cidrs

* reimplement early return if expr errors

* Fix lint and dont need to parse ip back to string just loop over sources

* Log error with node logger as it provides context

* Move getsource to a function cleanup some code

* Change func name

* Split out compile to a function so we can use in tests. Add a bunch of tests

* spell correction

* Use node logger so it has context

* alternative solution

* quick fixes

* Use containswls

* Change whitelist test to use parseipsource and only events

* Make it simpler

* Postoverflow tests, some basic ones to make sure it works

* Use official pkg

* Add @mmetc reco

* Add @mmetc reco

* Change if if to a switch to only evaluate once

* simplify assertions

---------

Co-authored-by: bui <thibault@crowdsec.net>
Co-authored-by: Marco Mariani <marco@crowdsec.net>
Laurence Jones преди 1 година
родител
ревизия
19de3a8a77
променени са 5 файла, в които са добавени 522 реда и са изтрити 92 реда
  1. 16 92
      pkg/parser/node.go
  2. 111 0
      pkg/parser/whitelist.go
  3. 300 0
      pkg/parser/whitelist_test.go
  4. 16 0
      pkg/types/event.go
  5. 79 0
      pkg/types/event_test.go

+ 16 - 92
pkg/parser/node.go

@@ -3,7 +3,6 @@ package parser
 import (
 	"errors"
 	"fmt"
-	"net"
 	"strings"
 	"time"
 
@@ -172,75 +171,24 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri
 	if n.Name != "" {
 		NodesHits.With(prometheus.Labels{"source": p.Line.Src, "type": p.Line.Module, "name": n.Name}).Inc()
 	}
-	isWhitelisted := false
-	hasWhitelist := false
-	var srcs []net.IP
-	/*overflow and log don't hold the source ip in the same field, should be changed */
-	/* perform whitelist checks for ips, cidr accordingly */
-	/* TODO move whitelist elsewhere */
-	if p.Type == types.LOG {
-		if _, ok := p.Meta["source_ip"]; ok {
-			srcs = append(srcs, net.ParseIP(p.Meta["source_ip"]))
-		}
-	} else if p.Type == types.OVFLW {
-		for k := range p.Overflow.Sources {
-			srcs = append(srcs, net.ParseIP(k))
-		}
+	exprErr := error(nil)
+	isWhitelisted := n.CheckIPsWL(p.ParseIPSources())
+	if !isWhitelisted {
+		isWhitelisted, exprErr = n.CheckExprWL(cachedExprEnv)
 	}
-	for _, src := range srcs {
-		if isWhitelisted {
-			break
-		}
-		for _, v := range n.Whitelist.B_Ips {
-			if v.Equal(src) {
-				clog.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason)
-				isWhitelisted = true
-			} else {
-				clog.Tracef("whitelist: %s is not eq [%s]", src, v)
-			}
-			hasWhitelist = true
-		}
-		for _, v := range n.Whitelist.B_Cidrs {
-			if v.Contains(src) {
-				clog.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason)
-				isWhitelisted = true
-			} else {
-				clog.Tracef("whitelist: %s not in [%s]", src, v)
-			}
-			hasWhitelist = true
-		}
+	if exprErr != nil {
+		// Previous code returned nil if there was an error, so we keep this behavior
+		return false, nil //nolint:nilerr
 	}
 
-	/* run whitelist expression tests anyway */
-	for eidx, e := range n.Whitelist.B_Exprs {
-		output, err := expr.Run(e.Filter, cachedExprEnv)
-		if err != nil {
-			clog.Warningf("failed to run whitelist expr : %v", err)
-			clog.Debug("Event leaving node : ko")
-			return false, nil
-		}
-		switch out := output.(type) {
-		case bool:
-			if n.Debug {
-				e.ExprDebugger.Run(clog, out, cachedExprEnv)
-			}
-			if out {
-				clog.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason)
-				isWhitelisted = true
-			}
-			hasWhitelist = true
-		default:
-			log.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx])
-		}
-	}
 	if isWhitelisted && !p.Whitelisted {
 		p.Whitelisted = true
 		p.WhitelistReason = n.Whitelist.Reason
 		/*huglily wipe the ban order if the event is whitelisted and it's an overflow */
 		if p.Type == types.OVFLW { /*don't do this at home kids */
 			ips := []string{}
-			for _, src := range srcs {
-				ips = append(ips, src.String())
+			for k := range p.Overflow.Sources {
+				ips = append(ips, k)
 			}
 			clog.Infof("Ban for %s whitelisted, reason [%s]", strings.Join(ips, ","), n.Whitelist.Reason)
 			p.Overflow.Whitelisted = true
@@ -395,9 +343,10 @@ func (n *Node) process(p *types.Event, ctx UnixParserCtx, expressionEnv map[stri
 	}
 
 	/*
-		This is to apply statics when the node *has* whitelists that successfully matched the node.
+		This is to apply statics when the node either was whitelisted, or is not a whitelist (it has no expr/ips wl)
+		It is overconvoluted and should be simplified
 	*/
-	if len(n.Statics) > 0 && (isWhitelisted || !hasWhitelist) {
+	if len(n.Statics) > 0 && (isWhitelisted || !n.ContainsWLs()) {
 		clog.Debugf("+ Processing %d statics", len(n.Statics))
 		// if all else is good in whitelist, process node's statics
 		err := n.ProcessStatics(n.Statics, p)
@@ -610,36 +559,11 @@ func (n *Node) compile(pctx *UnixParserCtx, ectx EnricherCtx) error {
 	}
 
 	/* compile whitelists if present */
-	for _, v := range n.Whitelist.Ips {
-		n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v))
-		n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v))
-		valid = true
-	}
-
-	for _, v := range n.Whitelist.Cidrs {
-		_, tnet, err := net.ParseCIDR(v)
-		if err != nil {
-			n.Logger.Fatalf("Unable to parse cidr whitelist '%s' : %v.", v, err)
-		}
-		n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet)
-		n.Logger.Debugf("adding cidr %s to whitelists", tnet)
-		valid = true
-	}
-
-	for _, filter := range n.Whitelist.Exprs {
-		expression := &ExprWhitelist{}
-		expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
-		if err != nil {
-			n.Logger.Fatalf("Unable to compile whitelist expression '%s' : %v.", filter, err)
-		}
-		expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
-		if err != nil {
-			log.Errorf("unable to build debug filter for '%s' : %s", filter, err)
-		}
-		n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression)
-		n.Logger.Debugf("adding expression %s to whitelists", filter)
-		valid = true
+	whitelistValid, err := n.CompileWLs()
+	if err != nil {
+		return err
 	}
+	valid = valid || whitelistValid
 
 	if !valid {
 		/* node is empty, error force return */

+ 111 - 0
pkg/parser/whitelist.go

@@ -1,11 +1,14 @@
 package parser
 
 import (
+	"fmt"
 	"net"
 
+	"github.com/antonmedv/expr"
 	"github.com/antonmedv/expr/vm"
 
 	"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
+	"github.com/crowdsecurity/crowdsec/pkg/types"
 )
 
 type Whitelist struct {
@@ -22,3 +25,111 @@ type ExprWhitelist struct {
 	Filter       *vm.Program
 	ExprDebugger *exprhelpers.ExprDebugger // used to debug expression by printing the content of each variable of the expression
 }
+
+func (n *Node) ContainsWLs() bool {
+	return n.ContainsIPLists() || n.ContainsExprLists()
+}
+
+func (n *Node) ContainsExprLists() bool {
+	return len(n.Whitelist.B_Exprs) > 0
+}
+
+func (n *Node) ContainsIPLists() bool {
+	return len(n.Whitelist.B_Ips) > 0 || len(n.Whitelist.B_Cidrs) > 0
+}
+
+func (n *Node) CheckIPsWL(srcs []net.IP) bool {
+	isWhitelisted := false
+	if !n.ContainsIPLists() {
+		return isWhitelisted
+	}
+	for _, src := range srcs {
+		if isWhitelisted {
+			break
+		}
+		for _, v := range n.Whitelist.B_Ips {
+			if v.Equal(src) {
+				n.Logger.Debugf("Event from [%s] is whitelisted by IP (%s), reason [%s]", src, v, n.Whitelist.Reason)
+				isWhitelisted = true
+				break
+			}
+			n.Logger.Tracef("whitelist: %s is not eq [%s]", src, v)
+		}
+		for _, v := range n.Whitelist.B_Cidrs {
+			if v.Contains(src) {
+				n.Logger.Debugf("Event from [%s] is whitelisted by CIDR (%s), reason [%s]", src, v, n.Whitelist.Reason)
+				isWhitelisted = true
+				break
+			}
+			n.Logger.Tracef("whitelist: %s not in [%s]", src, v)
+		}
+	}
+	return isWhitelisted
+}
+
+func (n *Node) CheckExprWL(cachedExprEnv map[string]interface{}) (bool, error) {
+	isWhitelisted := false
+
+	if !n.ContainsExprLists() {
+		return false, nil
+	}
+	/* run whitelist expression tests anyway */
+	for eidx, e := range n.Whitelist.B_Exprs {
+		//if we already know the event is whitelisted, skip the rest of the expressions
+		if isWhitelisted {
+			break
+		}
+		output, err := expr.Run(e.Filter, cachedExprEnv)
+		if err != nil {
+			n.Logger.Warningf("failed to run whitelist expr : %v", err)
+			n.Logger.Debug("Event leaving node : ko")
+			return isWhitelisted, err
+		}
+		switch out := output.(type) {
+		case bool:
+			if n.Debug {
+				e.ExprDebugger.Run(n.Logger, out, cachedExprEnv)
+			}
+			if out {
+				n.Logger.Debugf("Event is whitelisted by expr, reason [%s]", n.Whitelist.Reason)
+				isWhitelisted = true
+			}
+		default:
+			n.Logger.Errorf("unexpected type %t (%v) while running '%s'", output, output, n.Whitelist.Exprs[eidx])
+		}
+	}
+	return isWhitelisted, nil
+}
+
+func (n *Node) CompileWLs() (bool, error) {
+	for _, v := range n.Whitelist.Ips {
+		n.Whitelist.B_Ips = append(n.Whitelist.B_Ips, net.ParseIP(v))
+		n.Logger.Debugf("adding ip %s to whitelists", net.ParseIP(v))
+	}
+
+	for _, v := range n.Whitelist.Cidrs {
+		_, tnet, err := net.ParseCIDR(v)
+		if err != nil {
+			return false, fmt.Errorf("unable to parse cidr whitelist '%s' : %v", v, err)
+		}
+		n.Whitelist.B_Cidrs = append(n.Whitelist.B_Cidrs, tnet)
+		n.Logger.Debugf("adding cidr %s to whitelists", tnet)
+	}
+
+	for _, filter := range n.Whitelist.Exprs {
+		var err error
+		expression := &ExprWhitelist{}
+		expression.Filter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
+		if err != nil {
+			return false, fmt.Errorf("unable to compile whitelist expression '%s' : %v", filter, err)
+		}
+		expression.ExprDebugger, err = exprhelpers.NewDebugger(filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...)
+		if err != nil {
+			n.Logger.Errorf("unable to build debug filter for '%s' : %s", filter, err)
+		}
+
+		n.Whitelist.B_Exprs = append(n.Whitelist.B_Exprs, expression)
+		n.Logger.Debugf("adding expression %s to whitelists", filter)
+	}
+	return n.ContainsWLs(), nil
+}

+ 300 - 0
pkg/parser/whitelist_test.go

@@ -0,0 +1,300 @@
+package parser
+
+import (
+	"testing"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
+
+	"github.com/crowdsecurity/go-cs-lib/cstest"
+
+	"github.com/crowdsecurity/crowdsec/pkg/models"
+	"github.com/crowdsecurity/crowdsec/pkg/types"
+)
+
+func TestWhitelistCompile(t *testing.T) {
+	node := &Node{
+		Logger: log.NewEntry(log.New()),
+	}
+	tests := []struct {
+		name        string
+		whitelist   Whitelist
+		expectedErr string
+	}{
+		{
+			name: "Valid CIDR whitelist",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"127.0.0.1/24",
+				},
+			},
+		},
+		{
+			name: "Invalid CIDR whitelist",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"127.0.0.1/1000",
+				},
+			},
+			expectedErr: "invalid CIDR address",
+		},
+		{
+			name: "Valid EXPR whitelist",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"1==1",
+				},
+			},
+		},
+		{
+			name: "Invalid EXPR whitelist",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"evt.THISPROPERTYSHOULDERROR == true",
+				},
+			},
+			expectedErr: "types.Event has no field",
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			node.Whitelist = tt.whitelist
+			_, err := node.CompileWLs()
+			cstest.RequireErrorContains(t, err, tt.expectedErr)
+		})
+	}
+}
+
+func TestWhitelistCheck(t *testing.T) {
+	node := &Node{
+		Logger: log.NewEntry(log.New()),
+	}
+	tests := []struct {
+		name      string
+		whitelist Whitelist
+		event     *types.Event
+		expected  bool
+	}{
+		{
+			name: "IP Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Ips: []string{
+					"127.0.0.1",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.1",
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "IP Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Ips: []string{
+					"127.0.0.1",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.2",
+				},
+			},
+		},
+		{
+			name: "CIDR Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"127.0.0.1/32",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.1",
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "CIDR Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"127.0.0.1/32",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.2",
+				},
+			},
+		},
+		{
+			name: "EXPR Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"evt.Meta.source_ip == '127.0.0.1'",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.1",
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "EXPR Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"evt.Meta.source_ip == '127.0.0.1'",
+				},
+			},
+			event: &types.Event{
+				Meta: map[string]string{
+					"source_ip": "127.0.0.2",
+				},
+			},
+		},
+		{
+			name: "Postoverflow IP Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Ips: []string{
+					"192.168.1.1",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					Sources: map[string]models.Source{
+						"192.168.1.1": {},
+					},
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "Postoverflow IP Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Ips: []string{
+					"192.168.1.2",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					Sources: map[string]models.Source{
+						"192.168.1.1": {},
+					},
+				},
+			},
+		},
+		{
+			name: "Postoverflow CIDR Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"192.168.1.1/32",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					Sources: map[string]models.Source{
+						"192.168.1.1": {},
+					},
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "Postoverflow CIDR Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Cidrs: []string{
+					"192.168.1.2/32",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					Sources: map[string]models.Source{
+						"192.168.1.1": {},
+					},
+				},
+			},
+		},
+		{
+			name: "Postoverflow EXPR Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"evt.Overflow.APIAlerts[0].Source.Cn == 'test'",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					APIAlerts: []models.Alert{
+						{
+							Source: &models.Source{
+								Cn: "test",
+							},
+						},
+					},
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "Postoverflow EXPR Not Whitelisted",
+			whitelist: Whitelist{
+				Reason: "test",
+				Exprs: []string{
+					"evt.Overflow.APIAlerts[0].Source.Cn == 'test2'",
+				},
+			},
+			event: &types.Event{
+				Type: types.OVFLW,
+				Overflow: types.RuntimeAlert{
+					APIAlerts: []models.Alert{
+						{
+							Source: &models.Source{
+								Cn: "test",
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			var err error
+			node.Whitelist = tt.whitelist
+			node.CompileWLs()
+			isWhitelisted := node.CheckIPsWL(tt.event.ParseIPSources())
+			if !isWhitelisted {
+				isWhitelisted, err = node.CheckExprWL(map[string]interface{}{"evt": tt.event})
+			}
+			require.NoError(t, err)
+			require.Equal(t, tt.expected, isWhitelisted)
+		})
+	}
+}

+ 16 - 0
pkg/types/event.go

@@ -1,6 +1,7 @@
 package types
 
 import (
+	"net"
 	"time"
 
 	log "github.com/sirupsen/logrus"
@@ -73,6 +74,21 @@ func (e *Event) GetMeta(key string) string {
 	return ""
 }
 
+func (e *Event) ParseIPSources() []net.IP {
+	var srcs []net.IP
+	switch e.Type {
+	case LOG:
+		if _, ok := e.Meta["source_ip"]; ok {
+			srcs = append(srcs, net.ParseIP(e.Meta["source_ip"]))
+		}
+	case OVFLW:
+		for k := range e.Overflow.Sources {
+			srcs = append(srcs, net.ParseIP(k))
+		}
+	}
+	return srcs
+}
+
 // Move in leakybuckets
 const (
 	Undefined = ""

+ 79 - 0
pkg/types/event_test.go

@@ -0,0 +1,79 @@
+package types
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/crowdsecurity/crowdsec/pkg/models"
+)
+
+func TestParseIPSources(t *testing.T) {
+	tests := []struct {
+		name     string
+		evt      Event
+		expected []net.IP
+	}{
+		{
+			name: "ParseIPSources: Valid Log Sources",
+			evt: Event{
+				Type: LOG,
+				Meta: map[string]string{
+					"source_ip": "127.0.0.1",
+				},
+			},
+			expected: []net.IP{
+				net.ParseIP("127.0.0.1"),
+			},
+		},
+		{
+			name: "ParseIPSources: Valid Overflow Sources",
+			evt: Event{
+				Type: OVFLW,
+				Overflow: RuntimeAlert{
+					Sources: map[string]models.Source{
+						"127.0.0.1": {},
+					},
+				},
+			},
+			expected: []net.IP{
+				net.ParseIP("127.0.0.1"),
+			},
+		},
+		{
+			name: "ParseIPSources: Invalid Log Sources",
+			evt: Event{
+				Type: LOG,
+				Meta: map[string]string{
+					"source_ip": "IAMNOTANIP",
+				},
+			},
+			expected: []net.IP{
+				nil,
+			},
+		},
+		{
+			name: "ParseIPSources: Invalid Overflow Sources",
+			evt: Event{
+				Type: OVFLW,
+				Overflow: RuntimeAlert{
+					Sources: map[string]models.Source{
+						"IAMNOTANIP": {},
+					},
+				},
+			},
+			expected: []net.IP{
+				nil,
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			ips := tt.evt.ParseIPSources()
+			assert.Equal(t, ips, tt.expected)
+		})
+	}
+}