Ver código fonte

Fast bulk alert delete (#1791)

Shivam Sandbhor 2 anos atrás
pai
commit
74659a82ab
3 arquivos alterados com 52 adições e 30 exclusões
  1. 2 0
      cmd/crowdsec-cli/alerts.go
  2. 30 30
      pkg/database/alerts.go
  3. 20 0
      tests/bats/80_alerts.bats

+ 2 - 0
cmd/crowdsec-cli/alerts.go

@@ -328,6 +328,8 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
 				if contained != nil && *contained {
 					alertDeleteFilter.Contains = new(bool)
 				}
+				limit := 0
+				alertDeleteFilter.Limit = &limit
 			} else {
 				limit := 0
 				alertDeleteFilter = apiclient.AlertsDeleteOpts{Limit: &limit}

+ 30 - 30
pkg/database/alerts.go

@@ -16,6 +16,7 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/davecgh/go-spew/spew"
@@ -545,7 +546,8 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([
 	return ret, nil
 }
 
-func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) {
+func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) {
+	predicates := make([]predicate.Alert, 0)
 	var err error
 	var start_ip, start_sfx, end_ip, end_sfx int64
 	var hasActiveDecision bool
@@ -557,7 +559,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 	/*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */
 	if v, ok := filter["simulated"]; ok {
 		if v[0] == "false" {
-			alerts = alerts.Where(alert.SimulatedEQ(false))
+			predicates = append(predicates, alert.SimulatedEQ(false))
 		}
 	}
 
@@ -579,11 +581,11 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 			} else if strings.ToLower(scope) == "range" {
 				scope = types.Range
 			}
-			alerts = alerts.Where(alert.SourceScopeEQ(scope))
+			predicates = append(predicates, alert.SourceScopeEQ(scope))
 		case "value":
-			alerts = alerts.Where(alert.SourceValueEQ(value[0]))
+			predicates = append(predicates, alert.SourceValueEQ(value[0]))
 		case "scenario":
-			alerts = alerts.Where(alert.HasDecisionsWith(decision.ScenarioEQ(value[0])))
+			predicates = append(predicates, alert.HasDecisionsWith(decision.ScenarioEQ(value[0])))
 		case "ip", "range":
 			ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
 			if err != nil {
@@ -598,7 +600,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 			if since.IsZero() {
 				return nil, fmt.Errorf("Empty time now() - %s", since.String())
 			}
-			alerts = alerts.Where(alert.StartedAtGTE(since))
+			predicates = append(predicates, alert.StartedAtGTE(since))
 		case "created_before":
 			duration, err := types.ParseDuration(value[0])
 			if err != nil {
@@ -608,7 +610,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 			if since.IsZero() {
 				return nil, fmt.Errorf("Empty time now() - %s", since.String())
 			}
-			alerts = alerts.Where(alert.CreatedAtLTE(since))
+			predicates = append(predicates, alert.CreatedAtLTE(since))
 		case "until":
 			duration, err := types.ParseDuration(value[0])
 			if err != nil {
@@ -618,14 +620,14 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 			if until.IsZero() {
 				return nil, fmt.Errorf("Empty time now() - %s", until.String())
 			}
-			alerts = alerts.Where(alert.StartedAtLTE(until))
+			predicates = append(predicates, alert.StartedAtLTE(until))
 		case "decision_type":
-			alerts = alerts.Where(alert.HasDecisionsWith(decision.TypeEQ(value[0])))
+			predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0])))
 		case "origin":
-			alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(value[0])))
+			predicates = append(predicates, alert.HasDecisionsWith(decision.OriginEQ(value[0])))
 		case "include_capi": //allows to exclude one or more specific origins
 			if value[0] == "false" {
-				alerts = alerts.Where(alert.HasDecisionsWith(decision.Or(decision.OriginEQ("crowdsec"), decision.OriginEQ("cscli"))))
+				predicates = append(predicates, alert.HasDecisionsWith(decision.Or(decision.OriginEQ("crowdsec"), decision.OriginEQ("cscli"))))
 			} else if value[0] != "true" {
 				log.Errorf("Invalid bool '%s' for include_capi", value[0])
 			}
@@ -634,9 +636,9 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 				return nil, errors.Wrapf(ParseType, "'%s' is not a boolean: %s", value[0], err)
 			}
 			if hasActiveDecision {
-				alerts = alerts.Where(alert.HasDecisionsWith(decision.UntilGTE(time.Now().UTC())))
+				predicates = append(predicates, alert.HasDecisionsWith(decision.UntilGTE(time.Now().UTC())))
 			} else {
-				alerts = alerts.Where(alert.Not(alert.HasDecisions()))
+				predicates = append(predicates, alert.Not(alert.HasDecisions()))
 			}
 		case "limit":
 			continue
@@ -651,13 +653,13 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 
 	if ip_sz == 4 {
 		if contains { /*decision contains {start_ip,end_ip}*/
-			alerts = alerts.Where(alert.And(
+			predicates = append(predicates, alert.And(
 				alert.HasDecisionsWith(decision.StartIPLTE(start_ip)),
 				alert.HasDecisionsWith(decision.EndIPGTE(end_ip)),
 				alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))),
 			))
 		} else { /*decision is contained within {start_ip,end_ip}*/
-			alerts = alerts.Where(alert.And(
+			predicates = append(predicates, alert.And(
 				alert.HasDecisionsWith(decision.StartIPGTE(start_ip)),
 				alert.HasDecisionsWith(decision.EndIPLTE(end_ip)),
 				alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))),
@@ -666,7 +668,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 	} else if ip_sz == 16 {
 
 		if contains { /*decision contains {start_ip,end_ip}*/
-			alerts = alerts.Where(alert.And(
+			predicates = append(predicates, alert.And(
 				//matching addr size
 				alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))),
 				alert.Or(
@@ -690,7 +692,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 				),
 			))
 		} else { /*decision is contained within {start_ip,end_ip}*/
-			alerts = alerts.Where(alert.And(
+			predicates = append(predicates, alert.And(
 				//matching addr size
 				alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))),
 				alert.Or(
@@ -717,7 +719,14 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 	} else if ip_sz != 0 {
 		return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
 	}
-	return alerts, nil
+	return predicates, nil
+}
+func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) {
+	preds, err := AlertPredicatesFromFilter(filter)
+	if err != nil {
+		return nil, err
+	}
+	return alerts.Where(preds...), nil
 }
 
 func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) {
@@ -901,20 +910,11 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error {
 }
 
 func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) {
-	var err error
-
-	// Get all the alerts that match the filter
-	alertsToDelete, err := c.QueryAlertWithFilter(filter)
+	preds, err := AlertPredicatesFromFilter(filter)
 	if err != nil {
-		return 0, errors.Wrap(DeleteFail, "alert query failed")
+		return 0, err
 	}
-
-	deleted, err := c.DeleteAlertGraphBatch(alertsToDelete)
-	if err != nil {
-		c.Log.Warningf("DeleteAlertWithFilter : %s", err)
-		return 0, errors.Wrapf(DeleteFail, "%d alert(s)", len(alertsToDelete))
-	}
-	return deleted, nil
+	return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX)
 }
 
 func (c *Client) FlushOrphans() {

+ 20 - 0
tests/bats/80_alerts.bats

@@ -129,6 +129,26 @@ teardown() {
     # XXX TODO: delete by scope, id, value, scenario, range..
 }
 
+@test "cscli alerts delete (with cascade to decisions)" {
+    run -0 cscli decisions add -i 1.2.3.4
+    run -0 cscli decisions list -o json
+    run -0 jq '. | length' <(output)
+    assert_output 1
+
+    run -0 --separate-stderr cscli alerts delete -i 1.2.3.4
+    assert_stderr --partial 'alert(s) deleted'
+    run -0 cscli decisions list -o json
+    assert_output null
+}
+
+@test "cscli alerts delete (must ignore the query limit)" {
+    for i in $(seq 1 200); do
+        run -0 cscli decisions add -i 1.2.3.4
+    done
+    run -0 --separate-stderr cscli alerts delete -i 1.2.3.4
+    assert_stderr --partial '200 alert(s) deleted'
+}
+
 @test "bad duration" {
     skip 'TODO'
     run -0 cscli decisions add -i 10.20.30.40 -t ban