Browse Source

Allow bouncers to filter decisions by scope (#817)

Signed-off-by: Shivam Sandbhor <shivam@crowdsec.net>
Shivam Sandbhor 4 years ago
parent
commit
f25d02a7c8

+ 5 - 0
cmd/crowdsec-cli/utils.go

@@ -63,6 +63,11 @@ func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *
 		*scope = types.Ip
 		*scope = types.Ip
 	case "range":
 	case "range":
 		*scope = types.Range
 		*scope = types.Range
+	case "country":
+		*scope = types.Country
+	case "as":
+		*scope = types.AS
+
 	}
 	}
 	return nil
 	return nil
 }
 }

+ 1 - 1
pkg/acquisition/tests/test.log

@@ -1 +1 @@
-one log line
+one log line

+ 5 - 2
pkg/apiclient/decisions_service.go

@@ -3,6 +3,7 @@ package apiclient
 import (
 import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
+	"strings"
 
 
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	qs "github.com/google/go-querystring/query"
 	qs "github.com/google/go-querystring/query"
@@ -52,10 +53,12 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
 	return &decisions, resp, nil
 	return &decisions, resp, nil
 }
 }
 
 
-func (s *DecisionsService) GetStream(ctx context.Context, startup bool) (*models.DecisionsStreamResponse, *Response, error) {
+func (s *DecisionsService) GetStream(ctx context.Context, startup bool, scopes []string) (*models.DecisionsStreamResponse, *Response, error) {
 	var decisions models.DecisionsStreamResponse
 	var decisions models.DecisionsStreamResponse
-
 	u := fmt.Sprintf("%s/decisions/stream?startup=%t", s.client.URLPrefix, startup)
 	u := fmt.Sprintf("%s/decisions/stream?startup=%t", s.client.URLPrefix, startup)
+	if len(scopes) > 0 {
+		u += "&scopes=" + strings.Join(scopes, ",")
+	}
 	req, err := s.client.NewRequest("GET", u, nil)
 	req, err := s.client.NewRequest("GET", u, nil)
 	if err != nil {
 	if err != nil {
 		return nil, nil, err
 		return nil, nil, err

+ 2 - 2
pkg/apiclient/decisions_service_test.go

@@ -160,7 +160,7 @@ func TestDecisionsStream(t *testing.T) {
 		},
 		},
 	}
 	}
 
 
-	decisions, resp, err := newcli.Decisions.GetStream(context.Background(), true)
+	decisions, resp, err := newcli.Decisions.GetStream(context.Background(), true, []string{})
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
 	if resp.Response.StatusCode != http.StatusOK {
 	if resp.Response.StatusCode != http.StatusOK {
@@ -175,7 +175,7 @@ func TestDecisionsStream(t *testing.T) {
 	}
 	}
 
 
 	//and second call, we get empty lists
 	//and second call, we get empty lists
-	decisions, resp, err = newcli.Decisions.GetStream(context.Background(), false)
+	decisions, resp, err = newcli.Decisions.GetStream(context.Background(), false, []string{})
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
 	if resp.Response.StatusCode != http.StatusOK {
 	if resp.Response.StatusCode != http.StatusOK {

+ 1 - 1
pkg/apiserver/apic.go

@@ -234,7 +234,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
 func (a *apic) PullTop() error {
 func (a *apic) PullTop() error {
 	var err error
 	var err error
 
 
-	data, _, err := a.apiClient.Decisions.GetStream(context.Background(), a.startup)
+	data, _, err := a.apiClient.Decisions.GetStream(context.Background(), a.startup, []string{})
 	if err != nil {
 	if err != nil {
 		return errors.Wrap(err, "get stream")
 		return errors.Wrap(err, "get stream")
 	}
 	}

+ 11 - 4
pkg/apiserver/controllers/v1/decisions.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"strconv"
 	"strconv"
+	"strings"
 	"time"
 	"time"
 
 
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
@@ -127,10 +128,16 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	filters := make(map[string][]string)
+	filters["scope"] = []string{"ip", "range"}
+	if val, ok := gctx.Request.URL.Query()["scopes"]; ok {
+		filters["scope"] = strings.Split(val[0], ",")
+	}
+
 	// if the blocker just start, return all decisions
 	// if the blocker just start, return all decisions
 	if val, ok := gctx.Request.URL.Query()["startup"]; ok {
 	if val, ok := gctx.Request.URL.Query()["startup"]; ok {
 		if val[0] == "true" {
 		if val[0] == "true" {
-			data, err := c.DBClient.QueryAllDecisions()
+			data, err := c.DBClient.QueryAllDecisionsWithFilters(filters)
 			if err != nil {
 			if err != nil {
 				log.Errorf("failed querying decisions: %v", err)
 				log.Errorf("failed querying decisions: %v", err)
 				gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
 				gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@@ -144,7 +151,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
 			}
 			}
 
 
 			// getting expired decisions
 			// getting expired decisions
-			data, err = c.DBClient.QueryExpiredDecisions()
+			data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters)
 			if err != nil {
 			if err != nil {
 				log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
 				log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
 				gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
 				gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@@ -172,7 +179,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
 	}
 	}
 
 
 	// getting new decisions
 	// getting new decisions
-	data, err = c.DBClient.QueryNewDecisionsSince(bouncerInfo.LastPull)
+	data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters)
 	if err != nil {
 	if err != nil {
 		log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err)
 		log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err)
 		gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
 		gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
@@ -186,7 +193,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
 	}
 	}
 
 
 	// getting expired decisions
 	// getting expired decisions
-	data, err = c.DBClient.QueryExpiredDecisionsSince(bouncerInfo.LastPull.Add((-2 * time.Second))) // do we want to give exactly lastPull time ?
+	data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ?
 	if err != nil {
 	if err != nil {
 		log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
 		log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
 		gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
 		gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})

+ 2 - 2
pkg/apiserver/decisions_test.go

@@ -277,7 +277,7 @@ func TestGetDecision(t *testing.T) {
 	router.ServeHTTP(w, req)
 	router.ServeHTTP(w, req)
 
 
 	assert.Equal(t, 200, w.Code)
 	assert.Equal(t, 200, w.Code)
-	assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]")
+	assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"Ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]")
 
 
 }
 }
 
 
@@ -449,5 +449,5 @@ func TestStreamDecision(t *testing.T) {
 	router.ServeHTTP(w, req)
 	router.ServeHTTP(w, req)
 
 
 	assert.Equal(t, 200, w.Code)
 	assert.Equal(t, 200, w.Code)
-	assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]}")
+	assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"Ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]}")
 }
 }

+ 1 - 1
pkg/apiserver/tests/alert_sample.json

@@ -10,7 +10,7 @@
                 "duration": "1h",
                 "duration": "1h",
                 "origin": "test",
                 "origin": "test",
                 "scenario": "crowdsecurity/test",
                 "scenario": "crowdsecurity/test",
-                "scope": "ip",
+                "scope": "Ip",
                 "value": "127.0.0.1",
                 "value": "127.0.0.1",
                 "type": "ban"
                 "type": "ban"
             }
             }

+ 55 - 20
pkg/database/decisions.go

@@ -41,13 +41,19 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
 				return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
 				return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
 			}
 			}
 		case "scope":
 		case "scope":
-			var scope string = value[0]
-			if strings.ToLower(scope) == "ip" {
-				scope = types.Ip
-			} else if strings.ToLower(scope) == "range" {
-				scope = types.Range
+			for i, scope := range value {
+				switch strings.ToLower(scope) {
+				case "ip":
+					value[i] = types.Ip
+				case "range":
+					value[i] = types.Range
+				case "country":
+					value[i] = types.Country
+				case "as":
+					value[i] = types.AS
+				}
 			}
 			}
-			query = query.Where(decision.ScopeEQ(scope))
+			query = query.Where(decision.ScopeIn(value...))
 		case "value":
 		case "value":
 			query = query.Where(decision.ValueEQ(value[0]))
 			query = query.Where(decision.ValueEQ(value[0]))
 		case "type":
 		case "type":
@@ -165,37 +171,66 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec
 	return data, nil
 	return data, nil
 }
 }
 
 
-func (c *Client) QueryAllDecisions() ([]*ent.Decision, error) {
-	data, err := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now())).All(c.CTX)
+func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
+	query := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now()))
+	query, err := BuildDecisionRequestWithFilter(query, filters)
+
+	if err != nil {
+		c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters")
+	}
+
+	data, err := query.All(c.CTX)
 	if err != nil {
 	if err != nil {
-		c.Log.Warningf("QueryAllDecisions : %s", err)
-		return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions")
+		c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters")
 	}
 	}
 	return data, nil
 	return data, nil
 }
 }
 
 
-func (c *Client) QueryExpiredDecisions() ([]*ent.Decision, error) {
-	data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).All(c.CTX)
+func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
+	query := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now()))
+	query, err := BuildDecisionRequestWithFilter(query, filters)
+
+	if err != nil {
+		c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters")
+	}
+	data, err := query.All(c.CTX)
 	if err != nil {
 	if err != nil {
-		c.Log.Warningf("QueryExpiredDecisions : %s", err)
+		c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
 		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
 		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
 	}
 	}
 	return data, nil
 	return data, nil
 }
 }
 
 
-func (c *Client) QueryExpiredDecisionsSince(since time.Time) ([]*ent.Decision, error) {
-	data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).Where(decision.UntilGT(since)).All(c.CTX)
+func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) {
+	query := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).Where(decision.UntilGT(since))
+	query, err := BuildDecisionRequestWithFilter(query, filters)
 	if err != nil {
 	if err != nil {
-		c.Log.Warningf("QueryExpiredDecisionsSince : %s", err)
-		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
+		c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
 	}
 	}
+
+	data, err := query.All(c.CTX)
+	if err != nil {
+		c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
+	}
+
 	return data, nil
 	return data, nil
 }
 }
 
 
-func (c *Client) QueryNewDecisionsSince(since time.Time) ([]*ent.Decision, error) {
-	data, err := c.Ent.Decision.Query().Where(decision.CreatedAtGT(since)).All(c.CTX)
+func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) {
+	query := c.Ent.Decision.Query().Where(decision.CreatedAtGT(since))
+	query, err := BuildDecisionRequestWithFilter(query, filters)
+	if err != nil {
+		c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err)
+		return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
+	}
+	data, err := query.All(c.CTX)
 	if err != nil {
 	if err != nil {
-		c.Log.Warningf("QueryNewDecisionsSince : %s", err)
+		c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err)
 		return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
 		return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String())
 	}
 	}
 	return data, nil
 	return data, nil

+ 2 - 0
pkg/types/event.go

@@ -57,6 +57,8 @@ const (
 	Ip        = "Ip"
 	Ip        = "Ip"
 	Range     = "Range"
 	Range     = "Range"
 	Filter    = "Filter"
 	Filter    = "Filter"
+	Country   = "Country"
+	AS        = "AS"
 )
 )
 
 
 //Move in leakybuckets
 //Move in leakybuckets