From fd94e2c056855d63de43aa1cd8bac8e14483ffc1 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 4 Sep 2023 14:21:45 +0200 Subject: [PATCH] refactor alert/decisions insert/update to avoid database locking in bulk operations (#2446) --- cmd/crowdsec-cli/config_show.go | 6 + cmd/crowdsec-cli/decisions_import.go | 17 +- pkg/csconfig/database.go | 12 +- pkg/database/alerts.go | 740 +++++++++++---------------- pkg/database/database.go | 71 --- pkg/database/flush.go | 278 ++++++++++ test/bats/90_decisions.bats | 2 +- 7 files changed, 605 insertions(+), 521 deletions(-) create mode 100644 pkg/database/flush.go diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index f152bff06..9f5b11fc1 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -163,6 +163,12 @@ Central API: - User : {{.DbConfig.User}} - DB Name : {{.DbConfig.DbName}} {{- end }} +{{- if .DbConfig.MaxOpenConns }} + - Max Open Conns : {{.DbConfig.MaxOpenConns}} +{{- end }} +{{- if ne .DbConfig.DecisionBulkSize 0 }} + - Decision Bulk Size : {{.DbConfig.DecisionBulkSize}} +{{- end }} {{- if .DbConfig.Flush }} {{- if .DbConfig.Flush.MaxAge }} - Flush age : {{.DbConfig.Flush.MaxAge}} diff --git a/cmd/crowdsec-cli/decisions_import.go b/cmd/crowdsec-cli/decisions_import.go index e7ba1d83f..56fc37c87 100644 --- a/cmd/crowdsec-cli/decisions_import.go +++ b/cmd/crowdsec-cli/decisions_import.go @@ -188,7 +188,9 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { } } - alerts := models.AddAlertsRequest{} + if len(decisions) > 1000 { + log.Infof("You are about to add %d decisions, this may take a while", len(decisions)) + } for _, chunk := range slicetools.Chunks(decisions, batchSize) { log.Debugf("Processing chunk of %d decisions", len(chunk)) @@ -212,16 +214,11 @@ func runDecisionsImport(cmd *cobra.Command, args []string) error { ScenarioVersion: ptr.Of(""), Decisions: chunk, } - alerts = append(alerts, &importAlert) - } - if len(decisions) > 1000 { - log.Infof("You are about to add %d decisions, this may take a while", len(decisions)) - } - - _, _, err = Client.Alerts.Add(context.Background(), alerts) - if err != nil { - return err + _, _, err = Client.Alerts.Add(context.Background(), models.AddAlertsRequest{&importAlert}) + if err != nil { + return err + } } log.Infof("Imported %d decisions", len(decisions)) diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 91df34f20..61e5fccb4 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -12,7 +12,12 @@ import ( var DEFAULT_MAX_OPEN_CONNS = 100 -const defaultDecisionBulkSize = 50 +const ( + defaultDecisionBulkSize = 1000 + // we need an upper bound due to the sqlite limit of 32k variables in a query + // we have 15 variables per decision, so 32768/15 = 2184.5333 + maxDecisionBulkSize = 2000 +) type DatabaseCfg struct { User string `yaml:"user"` @@ -68,6 +73,11 @@ func (c *Config) LoadDBConfig() error { c.DbConfig.DecisionBulkSize = defaultDecisionBulkSize } + if c.DbConfig.DecisionBulkSize > maxDecisionBulkSize { + log.Warningf("decision_bulk_size too high (%d), setting to the maximum value of %d", c.DbConfig.DecisionBulkSize, maxDecisionBulkSize) + c.DbConfig.DecisionBulkSize = maxDecisionBulkSize + } + if c.DbConfig.Type == "sqlite" { if c.DbConfig.UseWal == nil { log.Warning("You are using sqlite without WAL, this can have a performance impact. If you do not store the database in a network share, set db_config.use_wal to true. Set explicitly to false to disable this warning.") diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index fcc2cdfdc..3a5041444 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -9,17 +9,18 @@ import ( "strings" "time" + "github.com/mattn/go-sqlite3" + "github.com/davecgh/go-spew/spew" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/go-cs-lib/slicetools" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" - "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" - "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -30,6 +31,7 @@ const ( paginationSize = 100 // used to queryAlert to avoid 'too many SQL variable' defaultLimit = 100 // default limit of element to returns when query alerts bulkSize = 50 // bulk size when create alerts + maxLockRetries = 10 // how many times to retry a bulk operation when sqlite3.ErrBusy is encountered ) func formatAlertCN(source models.Source) string { @@ -43,32 +45,36 @@ func formatAlertCN(source models.Source) string { } func formatAlertSource(alert *models.Alert) string { - if alert.Source == nil { + if alert.Source == nil || alert.Source.Scope == nil || *alert.Source.Scope == "" { return "empty source" } if *alert.Source.Scope == types.Ip { ret := "ip " + *alert.Source.Value + cn := formatAlertCN(*alert.Source) if cn != "" { ret += " (" + cn + ")" } + return ret } if *alert.Source.Scope == types.Range { ret := "range " + *alert.Source.Value + cn := formatAlertCN(*alert.Source) if cn != "" { ret += " (" + cn + ")" } + return ret } return *alert.Source.Scope + " " + *alert.Source.Value } -func formatAlertAsString(machineId string, alert *models.Alert) []string { +func formatAlertAsString(machineID string, alert *models.Alert) []string { src := formatAlertSource(alert) /**/ @@ -84,11 +90,15 @@ func formatAlertAsString(machineId string, alert *models.Alert) []string { reason := fmt.Sprintf("%s by %s", msg, src) if len(alert.Decisions) == 0 { - return []string{fmt.Sprintf("(%s) alert : %s", machineId, reason)} + return []string{fmt.Sprintf("(%s) alert : %s", machineID, reason)} } var retStr []string + if alert.Decisions[0].Origin != nil && *alert.Decisions[0].Origin == types.CscliImportOrigin { + return []string{fmt.Sprintf("(%s) alert : %s for %d decisions", machineID, reason, len(alert.Decisions))} + } + for i, decisionItem := range alert.Decisions { decision := "" if alert.Simulated != nil && *alert.Simulated { @@ -96,25 +106,29 @@ func formatAlertAsString(machineId string, alert *models.Alert) []string { } else if decisionItem.Simulated != nil && *decisionItem.Simulated { decision = "(simulated decision)" } + if log.GetLevel() >= log.DebugLevel { /*spew is expensive*/ log.Debugf("%s", spew.Sdump(decisionItem)) } + if len(alert.Decisions) > 1 { reason = fmt.Sprintf("%s for %d/%d decisions", msg, i+1, len(alert.Decisions)) } - machineIdOrigin := "" - if machineId == "" { - machineIdOrigin = *decisionItem.Origin + + machineIDOrigin := "" + if machineID == "" { + machineIDOrigin = *decisionItem.Origin } else { - machineIdOrigin = fmt.Sprintf("%s/%s", machineId, *decisionItem.Origin) + machineIDOrigin = fmt.Sprintf("%s/%s", machineID, *decisionItem.Origin) } decision += fmt.Sprintf("%s %s on %s %s", *decisionItem.Duration, *decisionItem.Type, *decisionItem.Scope, *decisionItem.Value) retStr = append(retStr, - fmt.Sprintf("(%s) %s : %s", machineIdOrigin, reason, decision)) + fmt.Sprintf("(%s) %s : %s", machineIDOrigin, reason, decision)) } + return retStr } @@ -122,7 +136,6 @@ func formatAlertAsString(machineId string, alert *models.Alert) []string { // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { - if alertItem.UUID == "" { return "", fmt.Errorf("alert UUID is empty") } @@ -135,11 +148,12 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) //alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - ret, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } - return ret[0], nil + + return alertIDs[0], nil } //this should never happen @@ -148,22 +162,26 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) } log.Infof("Alert %s already exists, checking associated decisions", alertItem.UUID) + //alert is found, check for any missing decisions - missingUuids := []string{} - newUuids := []string{} - for _, decItem := range alertItem.Decisions { - newUuids = append(newUuids, decItem.UUID) + + newUuids := make([]string, len(alertItem.Decisions)) + for i, decItem := range alertItem.Decisions { + newUuids[i] = decItem.UUID } foundAlert := alerts[0] - foundUuids := []string{} - for _, decItem := range foundAlert.Edges.Decisions { - foundUuids = append(foundUuids, decItem.UUID) + foundUuids := make([]string, len(foundAlert.Edges.Decisions)) + + for i, decItem := range foundAlert.Edges.Decisions { + foundUuids[i] = decItem.UUID } sort.Strings(foundUuids) sort.Strings(newUuids) + missingUuids := []string{} + for idx, uuid := range newUuids { if len(foundUuids) < idx+1 || uuid != foundUuids[idx] { log.Warningf("Decision with uuid %s not found in alert %s", uuid, foundAlert.UUID) @@ -176,9 +194,10 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) return "", nil } - //add any and all missing decisions based on their uuids - //prepare missing decisions + // add any and all missing decisions based on their uuids + // prepare missing decisions missingDecisions := []*models.Decision{} + for _, uuid := range missingUuids { for _, newDecision := range alertItem.Decisions { if newDecision.UUID == uuid { @@ -190,8 +209,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) //add missing decisions log.Debugf("Adding %d missing decisions to alert %s", len(missingDecisions), foundAlert.UUID) - decisions := make([]*ent.Decision, 0) - decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) + decisionBuilders := make([]*ent.DecisionCreate, len(missingDecisions)) for i, decisionItem := range missingDecisions { var start_ip, start_sfx, end_ip, end_sfx int64 @@ -204,20 +222,24 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) return "", errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err) } } + decisionDuration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { log.Warningf("invalid duration %s for decision %s", *decisionItem.Duration, decisionItem.UUID) continue } + //use the created_at from the alert instead alertTime, err := time.Parse(time.RFC3339, alertItem.CreatedAt) if err != nil { log.Errorf("unable to parse alert time %s : %s", alertItem.CreatedAt, err) + alertTime = time.Now() } + decisionUntil := alertTime.UTC().Add(decisionDuration) - decisionCreate := c.Ent.Decision.Create(). + decisionBuilder := c.Ent.Decision.Create(). SetUntil(decisionUntil). SetScenario(*decisionItem.Scenario). SetType(*decisionItem.Type). @@ -232,58 +254,34 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) SetSimulated(*alertItem.Simulated). SetUUID(decisionItem.UUID) - decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == c.decisionBulkSize { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return "", errors.Wrapf(BulkError, "creating alert decisions: %s", err) + decisionBuilders[i] = decisionBuilder + } - } - decisions = append(decisions, decisionsCreateRet...) - if len(missingDecisions)-i <= c.decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(missingDecisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) - } + decisions := []*ent.Decision{} + + builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) + + for _, builderChunk := range builderChunks { + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + if err != nil { + return "", fmt.Errorf("creating alert decisions: %w", err) } + + decisions = append(decisions, decisionsCreateRet...) } - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return "", errors.Wrapf(BulkError, "creating alert decisions: %s", err) - } - decisions = append(decisions, decisionsCreateRet...) + //now that we bulk created missing decisions, let's update the alert - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisions...).Exec(c.CTX) - if err != nil { - return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) + + decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) + + for _, decisionChunk := range decisionChunks { + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + if err != nil { + return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) + } } return "", nil - -} - -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { - pageStart := 0 - pageEnd := bulkSize - ret := []string{} - for { - if pageEnd >= len(alertList) { - results, err := c.CreateAlertBulk(machineID, alertList[pageStart:]) - if err != nil { - return []string{}, fmt.Errorf("unable to create alerts: %s", err) - } - ret = append(ret, results...) - break - } - results, err := c.CreateAlertBulk(machineID, alertList[pageStart:pageEnd]) - if err != nil { - return []string{}, fmt.Errorf("unable to create alerts: %s", err) - } - ret = append(ret, results...) - pageStart += bulkSize - pageEnd += bulkSize - } - return ret, nil } // UpdateCommunityBlocklist is called to update either the community blocklist (or other lists the user subscribed to) @@ -291,21 +289,23 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { - var err error - if alertItem == nil { return 0, 0, 0, fmt.Errorf("nil alert") } + if alertItem.StartAt == nil { return 0, 0, 0, fmt.Errorf("nil start_at") } + startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { return 0, 0, 0, errors.Wrapf(ParseTimeFail, "start_at field time '%s': %s", *alertItem.StartAt, err) } + if alertItem.StopAt == nil { return 0, 0, 0, fmt.Errorf("nil stop_at") } + stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { return 0, 0, 0, errors.Wrapf(ParseTimeFail, "stop_at field time '%s': %s", *alertItem.StopAt, err) @@ -314,6 +314,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in ts, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { c.Log.Errorf("While parsing StartAt of item %s : %s", *alertItem.StopAt, err) + ts = time.Now().UTC() } @@ -352,9 +353,9 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } - decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) - valueList := make([]string, 0, c.decisionBulkSize) + DecOrigin := CapiMachineID + if *alertItem.Decisions[0].Origin == CapiMachineID || *alertItem.Decisions[0].Origin == CapiListsMachineID { DecOrigin = *alertItem.Decisions[0].Origin } else { @@ -364,25 +365,33 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in deleted := 0 inserted := 0 - for i, decisionItem := range alertItem.Decisions { + decisionBuilders := make([]*ent.DecisionCreate, 0, len(alertItem.Decisions)) + valueList := make([]string, 0, len(alertItem.Decisions)) + + for _, decisionItem := range alertItem.Decisions { var start_ip, start_sfx, end_ip, end_sfx int64 var sz int + if decisionItem.Duration == nil { log.Warning("nil duration in community decision") continue } + duration, err := time.ParseDuration(*decisionItem.Duration) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) } + if decisionItem.Scope == nil { log.Warning("nil scope in community decision") continue } + /*if the scope is IP or Range, convert the value to integers */ if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) @@ -391,11 +400,13 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, errors.Wrapf(InvalidIPOrRange, "invalid addr/range %s : %s", *decisionItem.Value, err) } } + /*bulk insert some new decisions*/ - decisionBulk = append(decisionBulk, c.Ent.Decision.Create(). + decisionBuilder := c.Ent.Decision.Create(). SetUntil(ts.Add(duration)). SetScenario(*decisionItem.Scenario). SetType(*decisionItem.Type). @@ -408,145 +419,143 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScope(*decisionItem.Scope). SetOrigin(*decisionItem.Origin). SetSimulated(*alertItem.Simulated). - SetOwner(alertRef)) + SetOwner(alertRef) + + decisionBuilders = append(decisionBuilders, decisionBuilder) /*for bulk delete of duplicate decisions*/ if decisionItem.Value == nil { log.Warning("nil value in community decision") continue } + valueList = append(valueList, *decisionItem.Value) - - if len(decisionBulk) == c.decisionBulkSize { - - insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - return 0, 0, 0, errors.Wrapf(BulkError, "bulk creating decisions : %s", err) - } - inserted += len(insertedDecisions) - - /*Deleting older decisions from capi*/ - deletedDecisions, err := txClient.Decision.Delete(). - Where(decision.And( - decision.OriginEQ(DecOrigin), - decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), - decision.ValueIn(valueList...), - )).Exec(c.CTX) - if err != nil { - rollbackErr := txClient.Rollback() - if rollbackErr != nil { - log.Errorf("rollback error: %s", rollbackErr) - } - return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err) - } - deleted += deletedDecisions - - if len(alertItem.Decisions)-i <= c.decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) - valueList = make([]string, 0, (len(alertItem.Decisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) - valueList = make([]string, 0, c.decisionBulkSize) - } - } - } - log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin) - insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return 0, 0, 0, errors.Wrapf(BulkError, "creating alert decisions: %s", err) - } - inserted += len(insertedDecisions) - /*Deleting older decisions from capi*/ - if len(valueList) > 0 { + + deleteChunks := slicetools.Chunks(valueList, c.decisionBulkSize) + + for _, deleteChunk := range deleteChunks { + // Deleting older decisions from capi deletedDecisions, err := txClient.Decision.Delete(). Where(decision.And( decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), - decision.ValueIn(valueList...), + decision.ValueIn(deleteChunk...), )).Exec(c.CTX) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } + return 0, 0, 0, fmt.Errorf("while deleting older community blocklist decisions: %w", err) } + deleted += deletedDecisions } + + builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) + + for _, builderChunk := range builderChunks { + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + if err != nil { + rollbackErr := txClient.Rollback() + if rollbackErr != nil { + log.Errorf("rollback error: %s", rollbackErr) + } + + return 0, 0, 0, fmt.Errorf("while bulk creating decisions: %w", err) + } + + inserted += len(insertedDecisions) + } + + log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin) + err = txClient.Commit() if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { log.Errorf("rollback error: %s", rollbackErr) } - return 0, 0, 0, errors.Wrapf(BulkError, "error committing transaction : %s", err) + + return 0, 0, 0, fmt.Errorf("error committing transaction: %w", err) } return alertRef.ID, inserted, deleted, nil } -func chunkDecisions(decisions []*ent.Decision, chunkSize int) [][]*ent.Decision { - var ret [][]*ent.Decision - var chunk []*ent.Decision - for _, d := range decisions { - chunk = append(chunk, d) - if len(chunk) == chunkSize { - ret = append(ret, chunk) - chunk = nil +func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { + decisionCreate := make([]*ent.DecisionCreate, len(decisions)) + + for i, decisionItem := range decisions { + var start_ip, start_sfx, end_ip, end_sfx int64 + var sz int + + duration, err := time.ParseDuration(*decisionItem.Duration) + if err != nil { + return nil, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) } + + /*if the scope is IP or Range, convert the value to integers */ + if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { + sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) + if err != nil { + return nil, fmt.Errorf("%s: %w", *decisionItem.Value, InvalidIPOrRange) + } + } + + newDecision := c.Ent.Decision.Create(). + SetUntil(stopAtTime.Add(duration)). + SetScenario(*decisionItem.Scenario). + SetType(*decisionItem.Type). + SetStartIP(start_ip). + SetStartSuffix(start_sfx). + SetEndIP(end_ip). + SetEndSuffix(end_sfx). + SetIPSize(int64(sz)). + SetValue(*decisionItem.Value). + SetScope(*decisionItem.Scope). + SetOrigin(*decisionItem.Origin). + SetSimulated(simulated). + SetUUID(decisionItem.UUID) + + decisionCreate[i] = newDecision } - if len(chunk) > 0 { - ret = append(ret, chunk) + + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + if err != nil { + return nil, err } - return ret + + return ret, nil } -func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([]string, error) { - ret := []string{} - bulkSize := 20 - var owner *ent.Machine - var err error - if machineId != "" { - owner, err = c.QueryMachineByID(machineId) - if err != nil { - if errors.Cause(err) != UserNotExists { - return []string{}, errors.Wrapf(QueryFail, "machine '%s': %s", machineId, err) - } - c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineId) - owner = nil - } - } else { - owner = nil - } +func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { + alertBuilders := make([]*ent.AlertCreate, len(alerts)) + alertDecisions := make([][]*ent.Decision, len(alerts)) - c.Log.Debugf("writing %d items", len(alertList)) - bulk := make([]*ent.AlertCreate, 0, bulkSize) - alertDecisions := make([][]*ent.Decision, 0, bulkSize) - for i, alertItem := range alertList { - var decisions []*ent.Decision + for i, alertItem := range alerts { var metas []*ent.Meta var events []*ent.Event startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { c.Log.Errorf("CreateAlertBulk: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + startAtTime = time.Now().UTC() } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { c.Log.Errorf("CreateAlertBulk: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + stopAtTime = time.Now().UTC() } /*display proper alert in logs*/ - for _, disp := range formatAlertAsString(machineId, alertItem) { + for _, disp := range formatAlertAsString(machineID, alertItem) { c.Log.Info(disp) } @@ -556,12 +565,15 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ if len(alertItem.Events) > 0 { eventBulk := make([]*ent.EventCreate, len(alertItem.Events)) + for i, eventItem := range alertItem.Events { ts, err := time.Parse(time.RFC3339, *eventItem.Timestamp) if err != nil { c.Log.Errorf("CreateAlertBulk: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + ts = time.Now().UTC() } + marshallMetas, err := json.Marshal(eventItem.Meta) if err != nil { return nil, errors.Wrapf(MarshalFail, "event meta '%v' : %s", eventItem.Meta, err) @@ -573,6 +585,7 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ valid := false stripSize := 2048 + for !valid && stripSize > 0 { for _, serializedItem := range eventItem.Meta { if len(serializedItem.Value) > stripSize*2 { @@ -584,9 +597,11 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ if err != nil { return nil, errors.Wrapf(MarshalFail, "event meta '%v' : %s", eventItem.Meta, err) } + if event.SerializedValidator(string(marshallMetas)) == nil { valid = true } + stripSize /= 2 } @@ -596,19 +611,21 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ stripped = false marshallMetas = []byte("") } - } eventBulk[i] = c.Ent.Event.Create(). SetTime(ts). SetSerialized(string(marshallMetas)) } + if stripped { - c.Log.Warningf("stripped 'serialized' field (machine %s / scenario %s)", machineId, *alertItem.Scenario) + c.Log.Warningf("stripped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } + if dropped { - c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineId, *alertItem.Scenario) + c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) @@ -622,70 +639,26 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ SetKey(metaItem.Key). SetValue(metaItem.Value) } + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert meta: %s", err) } } - decisions = make([]*ent.Decision, 0) - if len(alertItem.Decisions) > 0 { - decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) - for i, decisionItem := range alertItem.Decisions { - var start_ip, start_sfx, end_ip, end_sfx int64 - var sz int + decisions := []*ent.Decision{} - duration, err := time.ParseDuration(*decisionItem.Duration) - if err != nil { - return nil, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err) - } - - /*if the scope is IP or Range, convert the value to integers */ - if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" { - sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value) - if err != nil { - return nil, fmt.Errorf("%s: %w", *decisionItem.Value, InvalidIPOrRange) - } - } - - decisionCreate := c.Ent.Decision.Create(). - SetUntil(stopAtTime.Add(duration)). - SetScenario(*decisionItem.Scenario). - SetType(*decisionItem.Type). - SetStartIP(start_ip). - SetStartSuffix(start_sfx). - SetEndIP(end_ip). - SetEndSuffix(end_sfx). - SetIPSize(int64(sz)). - SetValue(*decisionItem.Value). - SetScope(*decisionItem.Scope). - SetOrigin(*decisionItem.Origin). - SetSimulated(*alertItem.Simulated). - SetUUID(decisionItem.UUID) - - decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == c.decisionBulkSize { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert decisions: %s", err) - - } - decisions = append(decisions, decisionsCreateRet...) - if len(alertItem.Decisions)-i <= c.decisionBulkSize { - decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) - } else { - decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) - } - } - } - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) + decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) + for _, decisionChunk := range decisionChunks { + decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { - return nil, errors.Wrapf(BulkError, "creating alert decisions: %s", err) + return nil, fmt.Errorf("creating alert decisions: %w", err) } - decisions = append(decisions, decisionsCreateRet...) + + decisions = append(decisions, decisionRet...) } - alertB := c.Ent.Alert. + alertBuilder := c.Ent.Alert. Create(). SetScenario(*alertItem.Scenario). SetMessage(*alertItem.Message). @@ -711,50 +684,52 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ AddMetas(metas...) if owner != nil { - alertB.SetOwner(owner) + alertBuilder.SetOwner(owner) } - bulk = append(bulk, alertB) - alertDecisions = append(alertDecisions, decisions) - if len(bulk) == bulkSize { - alerts, err := c.Ent.Alert.CreateBulk(bulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) - } - for alertIndex, a := range alerts { - ret = append(ret, strconv.Itoa(a.ID)) - d := alertDecisions[alertIndex] - decisionsChunk := chunkDecisions(d, bulkSize) - for _, d2 := range decisionsChunk { - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) - if err != nil { - return nil, fmt.Errorf("error while updating decisions: %s", err) + alertBuilders[i] = alertBuilder + alertDecisions[i] = decisions + } + + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + if err != nil { + return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) + } + + ret := make([]string, len(alertsCreateBulk)) + for i, a := range alertsCreateBulk { + ret[i] = strconv.Itoa(a.ID) + + d := alertDecisions[i] + decisionsChunk := slicetools.Chunks(d, c.decisionBulkSize) + + for _, d2 := range decisionsChunk { + retry := 0 + + for retry < maxLockRetries { + // so much for the happy path... but sqlite3 errors work differently + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + if err == nil { + break + } + + if sqliteErr, ok := err.(sqlite3.Error); ok { + if sqliteErr.Code == sqlite3.ErrBusy { + // sqlite3.Error{ + // Code: 5, + // ExtendedCode: 5, + // SystemErrno: 0, + // err: "database is locked", + // } + retry++ + log.Warningf("while updating decisions, sqlite3.ErrBusy: %s, retry %d of %d", err, retry, maxLockRetries) + time.Sleep(1 * time.Second) + + continue } } - } - if len(alertList)-i <= bulkSize { - bulk = make([]*ent.AlertCreate, 0, (len(alertList) - i)) - alertDecisions = make([][]*ent.Decision, 0, (len(alertList) - i)) - } else { - bulk = make([]*ent.AlertCreate, 0, bulkSize) - alertDecisions = make([][]*ent.Decision, 0, bulkSize) - } - } - } - alerts, err := c.Ent.Alert.CreateBulk(bulk...).Save(c.CTX) - if err != nil { - return nil, errors.Wrapf(BulkError, "leftovers creating alert : %s", err) - } - - for alertIndex, a := range alerts { - ret = append(ret, strconv.Itoa(a.ID)) - d := alertDecisions[alertIndex] - decisionsChunk := chunkDecisions(d, bulkSize) - for _, d2 := range decisionsChunk { - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) - if err != nil { - return nil, fmt.Errorf("error while updating decisions: %s", err) + return nil, fmt.Errorf("error while updating decisions: %w", err) } } } @@ -762,13 +737,50 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ return ret, nil } + +func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { + var owner *ent.Machine + var err error + + if machineID != "" { + owner, err = c.QueryMachineByID(machineID) + if err != nil { + if errors.Cause(err) != UserNotExists { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } + + c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineID) + + owner = nil + } + } + + c.Log.Debugf("writing %d items", len(alertList)) + + alertChunks := slicetools.Chunks(alertList, bulkSize) + alertIDs := []string{} + + for _, alertChunk := range alertChunks { + ids, err := c.createAlertChunk(machineID, owner, alertChunk) + if err != nil { + return nil, fmt.Errorf("machine '%s': %w", machineID, err) + } + + alertIDs = append(alertIDs, ids...) + } + + return alertIDs, nil +} + func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { predicates := make([]predicate.Alert, 0) + var err error var start_ip, start_sfx, end_ip, end_sfx int64 var hasActiveDecision bool var ip_sz int - var contains bool = true + var contains = true + /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ @@ -791,12 +803,13 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - var scope string = value[0] + var scope = value[0] if strings.ToLower(scope) == "ip" { scope = types.Ip } else if strings.ToLower(scope) == "range" { scope = types.Range } + predicates = append(predicates, alert.SourceScopeEQ(scope)) case "value": predicates = append(predicates, alert.SourceValueEQ(value[0])) @@ -812,30 +825,36 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err != nil { return nil, fmt.Errorf("while parsing duration: %w", err) } + since := time.Now().UTC().Add(-duration) if since.IsZero() { - return nil, fmt.Errorf("Empty time now() - %s", since.String()) + return nil, fmt.Errorf("empty time now() - %s", since.String()) } + predicates = append(predicates, alert.StartedAtGTE(since)) case "created_before": duration, err := ParseDuration(value[0]) if err != nil { return nil, fmt.Errorf("while parsing duration: %w", err) } + since := time.Now().UTC().Add(-duration) if since.IsZero() { return nil, fmt.Errorf("empty time now() - %s", since.String()) } + predicates = append(predicates, alert.CreatedAtLTE(since)) case "until": duration, err := ParseDuration(value[0]) if err != nil { return nil, fmt.Errorf("while parsing duration: %w", err) } + until := time.Now().UTC().Add(-duration) if until.IsZero() { return nil, fmt.Errorf("empty time now() - %s", until.String()) } + predicates = append(predicates, alert.StartedAtLTE(until)) case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) @@ -855,6 +874,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if hasActiveDecision, err = strconv.ParseBool(value[0]); err != nil { return nil, errors.Wrapf(ParseType, "'%s' is not a boolean: %s", value[0], err) } + if hasActiveDecision { predicates = append(predicates, alert.HasDecisionsWith(decision.UntilGTE(time.Now().UTC()))) } else { @@ -888,7 +908,6 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e )) } } else if ip_sz == 16 { - if contains { /*decision contains {start_ip,end_ip}*/ predicates = append(predicates, alert.And( //matching addr size @@ -941,18 +960,20 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e } else if ip_sz != 0 { return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } + return predicates, nil } + func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return nil, err } + return alerts.Where(preds...), nil } func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) { - var res []struct { Scenario string Count int @@ -988,8 +1009,8 @@ func (c *Client) TotalAlerts() (int, error) { } func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { - sort := "DESC" // we sort by desc by default + if val, ok := filter["sort"]; ok { if val[0] != "ASC" && val[0] != "DESC" { c.Log.Errorf("invalid 'sort' parameter: %s", val) @@ -997,22 +1018,27 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, sort = val[0] } } + limit := defaultLimit + if val, ok := filter["limit"]; ok { limitConv, err := strconv.Atoi(val[0]) if err != nil { - return []*ent.Alert{}, errors.Wrapf(QueryFail, "bad limit in parameters: %s", val) + return nil, errors.Wrapf(QueryFail, "bad limit in parameters: %s", val) } - limit = limitConv + limit = limitConv } + offset := 0 ret := make([]*ent.Alert, 0) + for { alerts := c.Ent.Alert.Query() + alerts, err := BuildAlertRequestFromFilter(alerts, filter) if err != nil { - return []*ent.Alert{}, err + return nil, err } //only if with_decisions is present and set to false, we exclude this @@ -1022,6 +1048,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts. WithDecisions() } + alerts = alerts. WithEvents(). WithMetas(). @@ -1030,7 +1057,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, if limit == 0 { limit, err = alerts.Count(c.CTX) if err != nil { - return []*ent.Alert{}, fmt.Errorf("unable to count nb alerts: %s", err) + return nil, fmt.Errorf("unable to count nb alerts: %s", err) } } @@ -1042,23 +1069,27 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) if err != nil { - return []*ent.Alert{}, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) + return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } + if diff := limit - len(ret); diff < paginationSize { if len(result) < diff { ret = append(ret, result...) c.Log.Debugf("Pagination done, %d < %d", len(result), diff) + break } - ret = append(ret, result[0:diff]...) + ret = append(ret, result[0:diff]...) } else { ret = append(ret, result...) } + if len(ret) == limit || len(ret) == 0 || len(ret) < paginationSize { c.Log.Debugf("Pagination done len(ret) = %d", len(ret)) break } + offset += paginationSize } @@ -1153,180 +1184,10 @@ func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) if err != nil { return 0, err } + return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) } -func (c *Client) FlushOrphans() { - /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ - /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - - events_count, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) - if err != nil { - c.Log.Warningf("error while deleting orphan events : %s", err) - return - } - if events_count > 0 { - c.Log.Infof("%d deleted orphan events", events_count) - } - - events_count, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) - - if err != nil { - c.Log.Warningf("error while deleting orphan decisions : %s", err) - return - } - if events_count > 0 { - c.Log.Infof("%d deleted orphan decisions", events_count) - } -} - -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { - log.Debug("starting FlushAgentsAndBouncers") - if bouncersCfg != nil { - if bouncersCfg.ApiDuration != nil { - log.Debug("trying to delete old bouncers from api") - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.ApiDuration)), - ).Where( - bouncer.AuthTypeEQ(types.ApiKeyAuthType), - ).Exec(c.CTX) - if err != nil { - c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } - } - if bouncersCfg.CertDuration != nil { - log.Debug("trying to delete old bouncers from cert") - - deletionCount, err := c.Ent.Bouncer.Delete().Where( - bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.CertDuration)), - ).Where( - bouncer.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - if err != nil { - c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) - } - } - } - - if agentsCfg != nil { - if agentsCfg.CertDuration != nil { - log.Debug("trying to delete old agents from cert") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.CertDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.TlsAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (cert) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount) - } - } - if agentsCfg.LoginPasswordDuration != nil { - log.Debug("trying to delete old agents from password") - - deletionCount, err := c.Ent.Machine.Delete().Where( - machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.LoginPasswordDuration)), - ).Where( - machine.Not(machine.HasAlerts()), - ).Where( - machine.AuthTypeEQ(types.PasswordAuthType), - ).Exec(c.CTX) - log.Debugf("deleted %d entries", deletionCount) - if err != nil { - c.Log.Errorf("while auto-deleting expired machine (password) : %s", err) - } else if deletionCount > 0 { - c.Log.Infof("deleted %d expired machine (password auth)", deletionCount) - } - } - } - return nil -} - -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { - var deletedByAge int - var deletedByNbItem int - var totalAlerts int - var err error - - if !c.CanFlush { - c.Log.Debug("a list is being imported, flushing later") - return nil - } - - c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() - c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() - if err != nil { - c.Log.Warningf("FlushAlerts (max items count) : %s", err) - return fmt.Errorf("unable to get alerts count: %w", err) - } - c.Log.Debugf("FlushAlerts (Total alerts): %d", totalAlerts) - if MaxAge != "" { - filter := map[string][]string{ - "created_before": {MaxAge}, - } - nbDeleted, err := c.DeleteAlertWithFilter(filter) - if err != nil { - c.Log.Warningf("FlushAlerts (max age) : %s", err) - return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) - } - c.Log.Debugf("FlushAlerts (deleted max age alerts): %d", nbDeleted) - deletedByAge = nbDeleted - } - if MaxItems > 0 { - //We get the highest id for the alerts - //We subtract MaxItems to avoid deleting alerts that are not old enough - //This gives us the oldest alert that we want to keep - //We then delete all the alerts with an id lower than this one - //We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ - "sort": {"DESC"}, - "limit": {"1"}, - //we do not care about fetching the edges, we just want the id - "with_decisions": {"false"}, - }) - c.Log.Debugf("FlushAlerts (last alert): %+v", lastAlert) - if err != nil { - c.Log.Errorf("FlushAlerts: could not get last alert: %s", err) - return fmt.Errorf("could not get last alert: %w", err) - } - - if len(lastAlert) != 0 { - maxid := lastAlert[0].ID - MaxItems - - c.Log.Debugf("FlushAlerts (max id): %d", maxid) - - if maxid > 0 { - //This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) - - if err != nil { - c.Log.Errorf("FlushAlerts: Could not delete alerts : %s", err) - return fmt.Errorf("could not delete alerts: %w", err) - } - } - } - } - if deletedByNbItem > 0 { - c.Log.Infof("flushed %d/%d alerts because max number of alerts has been reached (%d max)", deletedByNbItem, totalAlerts, MaxItems) - } - if deletedByAge > 0 { - c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", deletedByAge, totalAlerts, MaxAge) - } - return nil -} - func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) if err != nil { @@ -1335,8 +1196,11 @@ func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { log.Warningf("GetAlertByID (not found): %s", err) return &ent.Alert{}, ItemNotFound } + c.Log.Warningf("GetAlertByID : %s", err) + return &ent.Alert{}, QueryFail } + return alert, nil } diff --git a/pkg/database/database.go b/pkg/database/database.go index 18b5dbf38..aa191d7dc 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -5,10 +5,8 @@ import ( "database/sql" "fmt" "os" - "time" entsql "entgo.io/ent/dialect/sql" - "github.com/go-co-op/gocron" _ "github.com/go-sql-driver/mysql" _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/mattn/go-sqlite3" @@ -105,72 +103,3 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { decisionBulkSize: config.DecisionBulkSize, }, nil } - -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { - maxItems := 0 - maxAge := "" - if config.MaxItems != nil && *config.MaxItems <= 0 { - return nil, fmt.Errorf("max_items can't be zero or negative number") - } - if config.MaxItems != nil { - maxItems = *config.MaxItems - } - if config.MaxAge != nil && *config.MaxAge != "" { - maxAge = *config.MaxAge - } - - // Init & Start cronjob every minute for alerts - scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) - if err != nil { - return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) - } - job.SingletonMode() - // Init & Start cronjob every hour for bouncers/agents - if config.AgentsGC != nil { - if config.AgentsGC.Cert != nil { - duration, err := ParseDuration(*config.AgentsGC.Cert) - if err != nil { - return nil, fmt.Errorf("while parsing agents cert auto-delete duration: %w", err) - } - config.AgentsGC.CertDuration = &duration - } - if config.AgentsGC.LoginPassword != nil { - duration, err := ParseDuration(*config.AgentsGC.LoginPassword) - if err != nil { - return nil, fmt.Errorf("while parsing agents login/password auto-delete duration: %w", err) - } - config.AgentsGC.LoginPasswordDuration = &duration - } - if config.AgentsGC.Api != nil { - log.Warning("agents auto-delete for API auth is not supported (use cert or login_password)") - } - } - if config.BouncersGC != nil { - if config.BouncersGC.Cert != nil { - duration, err := ParseDuration(*config.BouncersGC.Cert) - if err != nil { - return nil, fmt.Errorf("while parsing bouncers cert auto-delete duration: %w", err) - } - config.BouncersGC.CertDuration = &duration - } - if config.BouncersGC.Api != nil { - duration, err := ParseDuration(*config.BouncersGC.Api) - if err != nil { - return nil, fmt.Errorf("while parsing bouncers api auto-delete duration: %w", err) - } - config.BouncersGC.ApiDuration = &duration - } - if config.BouncersGC.LoginPassword != nil { - log.Warning("bouncers auto-delete for login/password auth is not supported (use cert or api)") - } - } - baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) - if err != nil { - return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) - } - baJob.SingletonMode() - scheduler.StartAsync() - - return scheduler, nil -} diff --git a/pkg/database/flush.go b/pkg/database/flush.go new file mode 100644 index 000000000..44d145d0b --- /dev/null +++ b/pkg/database/flush.go @@ -0,0 +1,278 @@ +package database + +import ( + "fmt" + "time" + + "github.com/go-co-op/gocron" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + + +func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { + maxItems := 0 + maxAge := "" + if config.MaxItems != nil && *config.MaxItems <= 0 { + return nil, fmt.Errorf("max_items can't be zero or negative number") + } + if config.MaxItems != nil { + maxItems = *config.MaxItems + } + if config.MaxAge != nil && *config.MaxAge != "" { + maxAge = *config.MaxAge + } + + // Init & Start cronjob every minute for alerts + scheduler := gocron.NewScheduler(time.UTC) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + if err != nil { + return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) + } + + job.SingletonMode() + // Init & Start cronjob every hour for bouncers/agents + if config.AgentsGC != nil { + if config.AgentsGC.Cert != nil { + duration, err := ParseDuration(*config.AgentsGC.Cert) + if err != nil { + return nil, fmt.Errorf("while parsing agents cert auto-delete duration: %w", err) + } + config.AgentsGC.CertDuration = &duration + } + if config.AgentsGC.LoginPassword != nil { + duration, err := ParseDuration(*config.AgentsGC.LoginPassword) + if err != nil { + return nil, fmt.Errorf("while parsing agents login/password auto-delete duration: %w", err) + } + config.AgentsGC.LoginPasswordDuration = &duration + } + if config.AgentsGC.Api != nil { + log.Warning("agents auto-delete for API auth is not supported (use cert or login_password)") + } + } + if config.BouncersGC != nil { + if config.BouncersGC.Cert != nil { + duration, err := ParseDuration(*config.BouncersGC.Cert) + if err != nil { + return nil, fmt.Errorf("while parsing bouncers cert auto-delete duration: %w", err) + } + config.BouncersGC.CertDuration = &duration + } + if config.BouncersGC.Api != nil { + duration, err := ParseDuration(*config.BouncersGC.Api) + if err != nil { + return nil, fmt.Errorf("while parsing bouncers api auto-delete duration: %w", err) + } + config.BouncersGC.ApiDuration = &duration + } + if config.BouncersGC.LoginPassword != nil { + log.Warning("bouncers auto-delete for login/password auth is not supported (use cert or api)") + } + } + baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + if err != nil { + return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) + } + + baJob.SingletonMode() + scheduler.StartAsync() + + return scheduler, nil +} + + +func (c *Client) FlushOrphans() { + /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ + /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + if err != nil { + c.Log.Warningf("error while deleting orphan events: %s", err) + return + } + if eventsCount > 0 { + c.Log.Infof("%d deleted orphan events", eventsCount) + } + + eventsCount, err = c.Ent.Decision.Delete().Where( + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) + + if err != nil { + c.Log.Warningf("error while deleting orphan decisions: %s", err) + return + } + if eventsCount > 0 { + c.Log.Infof("%d deleted orphan decisions", eventsCount) + } +} + +func (c *Client) flushBouncers(bouncersCfg *csconfig.AuthGCCfg) { + if bouncersCfg == nil { + return + } + + if bouncersCfg.ApiDuration != nil { + log.Debug("trying to delete old bouncers from api") + + deletionCount, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.ApiDuration)), + ).Where( + bouncer.AuthTypeEQ(types.ApiKeyAuthType), + ).Exec(c.CTX) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (api key): %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) + } + } + + if bouncersCfg.CertDuration != nil { + log.Debug("trying to delete old bouncers from cert") + + deletionCount, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(-*bouncersCfg.CertDuration)), + ).Where( + bouncer.AuthTypeEQ(types.TlsAuthType), + ).Exec(c.CTX) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (api key): %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) + } + } +} + +func (c *Client) flushAgents(agentsCfg *csconfig.AuthGCCfg) { + if agentsCfg == nil { + return + } + + if agentsCfg.CertDuration != nil { + log.Debug("trying to delete old agents from cert") + + deletionCount, err := c.Ent.Machine.Delete().Where( + machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.CertDuration)), + ).Where( + machine.Not(machine.HasAlerts()), + ).Where( + machine.AuthTypeEQ(types.TlsAuthType), + ).Exec(c.CTX) + log.Debugf("deleted %d entries", deletionCount) + if err != nil { + c.Log.Errorf("while auto-deleting expired machine (cert): %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount) + } + } + + if agentsCfg.LoginPasswordDuration != nil { + log.Debug("trying to delete old agents from password") + + deletionCount, err := c.Ent.Machine.Delete().Where( + machine.LastHeartbeatLTE(time.Now().UTC().Add(-*agentsCfg.LoginPasswordDuration)), + ).Where( + machine.Not(machine.HasAlerts()), + ).Where( + machine.AuthTypeEQ(types.PasswordAuthType), + ).Exec(c.CTX) + log.Debugf("deleted %d entries", deletionCount) + if err != nil { + c.Log.Errorf("while auto-deleting expired machine (password): %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired machine (password auth)", deletionCount) + } + } +} + +func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { + log.Debug("starting FlushAgentsAndBouncers") + + c.flushBouncers(bouncersCfg) + c.flushAgents(agentsCfg) + + return nil +} + +func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { + var deletedByAge int + var deletedByNbItem int + var totalAlerts int + var err error + + if !c.CanFlush { + c.Log.Debug("a list is being imported, flushing later") + return nil + } + + c.Log.Debug("Flushing orphan alerts") + c.FlushOrphans() + c.Log.Debug("Done flushing orphan alerts") + totalAlerts, err = c.TotalAlerts() + if err != nil { + c.Log.Warningf("FlushAlerts (max items count): %s", err) + return fmt.Errorf("unable to get alerts count: %w", err) + } + + c.Log.Debugf("FlushAlerts (Total alerts): %d", totalAlerts) + if MaxAge != "" { + filter := map[string][]string{ + "created_before": {MaxAge}, + } + nbDeleted, err := c.DeleteAlertWithFilter(filter) + if err != nil { + c.Log.Warningf("FlushAlerts (max age): %s", err) + return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) + } + + c.Log.Debugf("FlushAlerts (deleted max age alerts): %d", nbDeleted) + deletedByAge = nbDeleted + } + if MaxItems > 0 { + //We get the highest id for the alerts + //We subtract MaxItems to avoid deleting alerts that are not old enough + //This gives us the oldest alert that we want to keep + //We then delete all the alerts with an id lower than this one + //We can do this because the id is auto-increment, and the database won't reuse the same id twice + lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + "sort": {"DESC"}, + "limit": {"1"}, + //we do not care about fetching the edges, we just want the id + "with_decisions": {"false"}, + }) + c.Log.Debugf("FlushAlerts (last alert): %+v", lastAlert) + if err != nil { + c.Log.Errorf("FlushAlerts: could not get last alert: %s", err) + return fmt.Errorf("could not get last alert: %w", err) + } + + if len(lastAlert) != 0 { + maxid := lastAlert[0].ID - MaxItems + + c.Log.Debugf("FlushAlerts (max id): %d", maxid) + + if maxid > 0 { + //This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) + + if err != nil { + c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) + return fmt.Errorf("could not delete alerts: %w", err) + } + } + } + } + if deletedByNbItem > 0 { + c.Log.Infof("flushed %d/%d alerts because max number of alerts has been reached (%d max)", deletedByNbItem, totalAlerts, MaxItems) + } + if deletedByAge > 0 { + c.Log.Infof("flushed %d/%d alerts because they were created %s ago or more", deletedByAge, totalAlerts, MaxAge) + } + return nil +} diff --git a/test/bats/90_decisions.bats b/test/bats/90_decisions.bats index bcb410de9..f2464084a 100644 --- a/test/bats/90_decisions.bats +++ b/test/bats/90_decisions.bats @@ -163,7 +163,7 @@ teardown() { whatever EOT assert_stderr --partial 'Parsing values' - assert_stderr --partial 'API error: unable to create alerts: whatever: invalid ip address / range' + assert_stderr --partial 'creating alert decisions: whatever: invalid ip address / range' #---------- # Batch