Pārlūkot izejas kodu

refact BulkDeleteDecisions (#2308)

Code cleanup and de-duplication.
mmetc 1 gadu atpakaļ
vecāks
revīzija
15542b78fb
2 mainītis faili ar 84 papildinājumiem un 85 dzēšanām
  1. 46 40
      pkg/apiserver/apic.go
  2. 38 45
      pkg/database/decisions.go

+ 46 - 40
pkg/apiserver/apic.go

@@ -7,7 +7,6 @@ import (
 	"net"
 	"net/http"
 	"net/url"
-	"slices"
 	"strconv"
 	"strings"
 	"sync"
@@ -17,6 +16,7 @@ import (
 	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
 	"gopkg.in/tomb.v2"
+	"slices"
 
 	"github.com/crowdsecurity/go-cs-lib/ptr"
 	"github.com/crowdsecurity/go-cs-lib/trace"
@@ -383,19 +383,16 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
 }
 
 func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) {
-	var filter map[string][]string
-	var nbDeleted int
+	nbDeleted := 0
 	for _, decision := range deletedDecisions {
-		if strings.ToLower(*decision.Scope) == "ip" {
-			filter = make(map[string][]string, 1)
-			filter["value"] = []string{*decision.Value}
-		} else {
-			filter = make(map[string][]string, 3)
-			filter["value"] = []string{*decision.Value}
+		filter := map[string][]string{
+			"value":  {*decision.Value},
+			"origin": {*decision.Origin},
+		}
+		if strings.ToLower(*decision.Scope) != "ip" {
 			filter["type"] = []string{*decision.Type}
 			filter["scopes"] = []string{*decision.Scope}
 		}
-		filter["origin"] = []string{*decision.Origin}
 
 		dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
 		if err != nil {
@@ -412,20 +409,17 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
 }
 
 func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) {
-	var filter map[string][]string
 	var nbDeleted int
 	for _, decisions := range deletedDecisions {
 		scope := decisions.Scope
 		for _, decision := range decisions.Decisions {
-			if strings.ToLower(*scope) == "ip" {
-				filter = make(map[string][]string, 1)
-				filter["value"] = []string{decision}
-			} else {
-				filter = make(map[string][]string, 2)
-				filter["value"] = []string{decision}
+			filter := map[string][]string{
+				"value":  {decision},
+				"origin": {types.CAPIOrigin},
+			}
+			if strings.ToLower(*scope) != "ip" {
 				filter["scopes"] = []string{*scope}
 			}
-			filter["origin"] = []string{types.CAPIOrigin}
 
 			dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
 			if err != nil {
@@ -479,30 +473,42 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
 }
 
 func createAlertForDecision(decision *models.Decision) *models.Alert {
-	newAlert := &models.Alert{}
-	newAlert.Source = &models.Source{}
-	newAlert.Source.Scope = ptr.Of("")
-	if *decision.Origin == types.CAPIOrigin { //to make things more user friendly, we replace CAPI with community-blocklist
-		newAlert.Scenario = ptr.Of(types.CAPIOrigin)
-		newAlert.Source.Scope = ptr.Of(types.CAPIOrigin)
-	} else if *decision.Origin == types.ListOrigin {
-		newAlert.Scenario = ptr.Of(*decision.Scenario)
-		newAlert.Source.Scope = ptr.Of(types.ListOrigin)
-	} else {
+	var (
+		scenario string
+		scope    string
+	)
+
+	switch *decision.Origin {
+	case types.CAPIOrigin:
+		scenario = types.CAPIOrigin
+		scope = types.CAPIOrigin
+	case types.ListOrigin:
+		scenario = *decision.Scenario
+		scope = types.ListOrigin
+	default:
+		// XXX: this or nil?
+		scenario = ""
+		scope = ""
 		log.Warningf("unknown origin %s", *decision.Origin)
 	}
-	newAlert.Message = ptr.Of("")
-	newAlert.Source.Value = ptr.Of("")
-	newAlert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
-	newAlert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
-	newAlert.Capacity = ptr.Of(int32(0))
-	newAlert.Simulated = ptr.Of(false)
-	newAlert.EventsCount = ptr.Of(int32(0))
-	newAlert.Leakspeed = ptr.Of("")
-	newAlert.ScenarioHash = ptr.Of("")
-	newAlert.ScenarioVersion = ptr.Of("")
-	newAlert.MachineID = database.CapiMachineID
-	return newAlert
+
+	return &models.Alert{
+		Source: &models.Source{
+			Scope: ptr.Of(scope),
+			Value: ptr.Of(""),
+		},
+		Scenario:        ptr.Of(scenario),
+		Message:         ptr.Of(""),
+		StartAt:         ptr.Of(time.Now().UTC().Format(time.RFC3339)),
+		StopAt:          ptr.Of(time.Now().UTC().Format(time.RFC3339)),
+		Capacity:        ptr.Of(int32(0)),
+		Simulated:       ptr.Of(false),
+		EventsCount:     ptr.Of(int32(0)),
+		Leakspeed:       ptr.Of(""),
+		ScenarioHash:    ptr.Of(""),
+		ScenarioVersion: ptr.Of(""),
+		MachineID:       database.CapiMachineID,
+	}
 }
 
 // This function takes in list of parent alerts and decisions and then pairs them up.

+ 38 - 45
pkg/database/decisions.go

@@ -9,6 +9,8 @@ import (
 	"entgo.io/ent/dialect/sql"
 	"github.com/pkg/errors"
 
+	"github.com/crowdsecurity/go-cs-lib/slicetools"
+
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
@@ -23,7 +25,6 @@ type DecisionsByScenario struct {
 }
 
 func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
-
 	var err error
 	var start_ip, start_sfx, end_ip, end_sfx int64
 	var ip_sz int
@@ -545,55 +546,39 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri
 
 // BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them.
 // We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI
-func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDelete bool) (int, error) {
-	bulkSize := 256 //scientifically proven to be the best value for bulk delete
-	idsToDelete := make([]int, 0, bulkSize)
-	totalUpdates := 0
-	for i := 0; i < len(DecisionsToDelete); i++ {
-		idsToDelete = append(idsToDelete, DecisionsToDelete[i].ID)
-		if len(idsToDelete) == bulkSize {
-
-			if softDelete {
-				nbUpdates, err := c.Ent.Decision.Update().Where(
-					decision.IDIn(idsToDelete...),
-				).SetUntil(time.Now().UTC()).Save(c.CTX)
-				if err != nil {
-					return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter")
-				}
-				totalUpdates += nbUpdates
-			} else {
-				nbUpdates, err := c.Ent.Decision.Delete().Where(
-					decision.IDIn(idsToDelete...),
-				).Exec(c.CTX)
-				if err != nil {
-					return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter")
-				}
-				totalUpdates += nbUpdates
-			}
-			idsToDelete = make([]int, 0, bulkSize)
-		}
+func (c *Client) BulkDeleteDecisions(decisionsToDelete []*ent.Decision, softDelete bool) (int, error) {
+	const bulkSize = 256 //scientifically proven to be the best value for bulk delete
+
+	var (
+		nbUpdates    int
+		err          error
+		totalUpdates = 0
+	)
+
+	idsToDelete := make([]int, len(decisionsToDelete))
+	for i, decision := range decisionsToDelete {
+		idsToDelete[i] = decision.ID
 	}
 
-	if len(idsToDelete) > 0 {
+	for _, chunk := range slicetools.Chunks(idsToDelete, bulkSize) {
 		if softDelete {
-			nbUpdates, err := c.Ent.Decision.Update().Where(
-				decision.IDIn(idsToDelete...),
+			nbUpdates, err = c.Ent.Decision.Update().Where(
+				decision.IDIn(chunk...),
 			).SetUntil(time.Now().UTC()).Save(c.CTX)
 			if err != nil {
-				return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter")
+				return totalUpdates, fmt.Errorf("soft delete decisions with provided filter: %w", err)
 			}
-			totalUpdates += nbUpdates
 		} else {
-			nbUpdates, err := c.Ent.Decision.Delete().Where(
-				decision.IDIn(idsToDelete...),
+			nbUpdates, err = c.Ent.Decision.Delete().Where(
+				decision.IDIn(chunk...),
 			).Exec(c.CTX)
 			if err != nil {
-				return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter")
+				return totalUpdates, fmt.Errorf("hard delete decisions with provided filter: %w", err)
 			}
-			totalUpdates += nbUpdates
 		}
-
+		totalUpdates += nbUpdates
 	}
+
 	return totalUpdates, nil
 }
 
@@ -601,6 +586,7 @@ func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDele
 func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) {
 	toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
 
+	// XXX: do we want 500 or 404 here?
 	if err != nil || len(toUpdate) == 0 {
 		c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate))
 		return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
@@ -609,6 +595,7 @@ func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, e
 	if len(toUpdate) == 0 {
 		return 0, nil, ItemNotFound
 	}
+
 	count, err := c.BulkDeleteDecisions(toUpdate, true)
 	return count, toUpdate, err
 }
@@ -639,10 +626,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
 }
 
 func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) {
-	var err error
-	var start_ip, start_sfx, end_ip, end_sfx int64
-	var ip_sz, count int
-	ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue)
+	ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue)
 
 	if err != nil {
 		return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err)
@@ -652,11 +636,13 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim
 	decisions := c.Ent.Decision.Query().Where(
 		decision.CreatedAtGT(since),
 	)
+
 	decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
 	if err != nil {
 		return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
 	}
-	count, err = decisions.Count(c.CTX)
+
+	count, err := decisions.Count(c.CTX)
 	if err != nil {
 		return 0, errors.Wrapf(err, "fail to count decisions")
 	}
@@ -681,7 +667,10 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz
 				decision.IPSizeEQ(int64(ip_sz)),
 			))
 		}
-	} else if ip_sz == 16 {
+		return decisions, nil
+	}
+
+	if ip_sz == 16 {
 		/*decision contains {start_ip,end_ip}*/
 		if contains {
 			decisions = decisions.Where(decision.And(
@@ -733,9 +722,13 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz
 				),
 			))
 		}
-	} else if ip_sz != 0 {
+		return decisions, nil
+	}
+
+	if ip_sz != 0 {
 		return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz)
 	}
+
 	return decisions, nil
 }