Browse Source

Improve CAPI pull management (#871)

* prepare for new consensus : thousands of ips

Co-authored-by: Sebastien Blot <sebastien@crowdsec.net>
Thibault "bui" Koechlin 3 years ago
parent
commit
c188d401a3

+ 50 - 43
pkg/apiserver/apic.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"net/url"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -12,6 +13,8 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
 	"github.com/crowdsecurity/crowdsec/pkg/cwversion"
 	"github.com/crowdsecurity/crowdsec/pkg/database"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/go-openapi/strfmt"
@@ -234,6 +237,18 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
 func (a *apic) PullTop() error {
 	var err error
 
+	/*only pull community blocklist if it's older than 1h30 */
+	alerts := a.dbClient.Ent.Alert.Query()
+	alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
+	alerts = alerts.Where(alert.CreatedAtGTE(time.Now().Add(-time.Duration(1*time.Hour + 30*time.Minute))))
+	count, err := alerts.Count(a.dbClient.CTX)
+	if err != nil {
+		return errors.Wrap(err, "while looking for CAPI alert")
+	}
+	if count > 0 {
+		log.Printf("last CAPI pull is newer than 1h30, skip.")
+		return nil
+	}
 	data, _, err := a.apiClient.Decisions.GetStream(context.Background(), a.startup, []string{})
 	if err != nil {
 		return errors.Wrap(err, "get stream")
@@ -243,6 +258,7 @@ func (a *apic) PullTop() error {
 	}
 	// process deleted decisions
 	var filter map[string][]string
+	var nbDeleted int
 	for _, decision := range data.Deleted {
 		if strings.ToLower(*decision.Scope) == "ip" {
 			filter = make(map[string][]string, 1)
@@ -254,27 +270,40 @@ func (a *apic) PullTop() error {
 			filter["value"] = []string{*decision.Scope}
 		}
 
-		nbDeleted, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
+		dbCliRet, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
 		if err != nil {
-			return err
+			return errors.Wrap(err, "deleting decisions error")
 		}
-
-		log.Printf("pull top: deleted %s entries", nbDeleted)
+		dbCliDel, err := strconv.Atoi(dbCliRet)
+		if err != nil {
+			return errors.Wrapf(err, "converting db ret %d", dbCliDel)
+		}
+		nbDeleted += dbCliDel
 	}
+	log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
 
-	alertCreated, err := a.dbClient.Ent.Alert.
-		Create().
-		SetScenario(fmt.Sprintf("update : +%d/-%d IPs", len(data.New), len(data.Deleted))).
-		SetSourceScope("Community blocklist").
-		Save(a.dbClient.CTX)
-	if err != nil {
-		return errors.Wrap(err, "create alert from crowdsec-api")
+	if len(data.New) == 0 {
+		log.Warnf("capi/community-blocklist : received 0 new entries, CAPI failure ?")
+		return nil
 	}
 
+	capiPullTopX := models.Alert{}
+	capiPullTopX.Scenario = types.StrPtr(fmt.Sprintf("update : +%d/-%d IPs", len(data.New), len(data.Deleted)))
+	capiPullTopX.Message = types.StrPtr("")
+	capiPullTopX.Source = &models.Source{}
+	capiPullTopX.Source.Scope = types.StrPtr("crowdsec/community-blocklist")
+	capiPullTopX.Source.Value = types.StrPtr("")
+	capiPullTopX.StartAt = types.StrPtr(time.Now().Format(time.RFC3339))
+	capiPullTopX.StopAt = types.StrPtr(time.Now().Format(time.RFC3339))
+	capiPullTopX.Capacity = types.Int32Ptr(0)
+	capiPullTopX.Simulated = types.BoolPtr(false)
+	capiPullTopX.EventsCount = types.Int32Ptr(int32(len(data.New)))
+	capiPullTopX.Leakspeed = types.StrPtr("")
+	capiPullTopX.ScenarioHash = types.StrPtr("")
+	capiPullTopX.ScenarioVersion = types.StrPtr("")
+	capiPullTopX.MachineID = database.CapiMachineID
 	// process new decisions
 	for _, decision := range data.New {
-		var start_ip, start_sfx, end_ip, end_sfx int64
-		var sz int
 
 		/*CAPI might send lower case scopes, unify it.*/
 		switch strings.ToLower(*decision.Scope) {
@@ -284,36 +313,16 @@ func (a *apic) PullTop() error {
 			*decision.Scope = types.Range
 		}
 
-		/*if the scope is IP or Range, convert the value to integers */
-		if strings.ToLower(*decision.Scope) == "ip" || strings.ToLower(*decision.Scope) == "range" {
-			sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decision.Value)
-			if err != nil {
-				return errors.Wrapf(err, "invalid ip/range %s", *decision.Value)
-			}
-		}
+		capiPullTopX.Decisions = append(capiPullTopX.Decisions, decision)
+	}
 
-		duration, err := time.ParseDuration(*decision.Duration)
-		if err != nil {
-			return errors.Wrapf(err, "parse decision duration '%s':", *decision.Duration)
-		}
-		_, err = a.dbClient.Ent.Decision.Create().
-			SetUntil(time.Now().Add(duration)).
-			SetScenario(*decision.Scenario).
-			SetType(*decision.Type).
-			SetIPSize(int64(sz)).
-			SetStartIP(start_ip).
-			SetStartSuffix(start_sfx).
-			SetEndIP(end_ip).
-			SetEndSuffix(end_sfx).
-			SetValue(*decision.Value).
-			SetScope(*decision.Scope).
-			SetOrigin(*decision.Origin).
-			SetOwner(alertCreated).Save(a.dbClient.CTX)
-		if err != nil {
-			return errors.Wrap(err, "decision creation from crowdsec-api:")
-		}
+	alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(&capiPullTopX)
+	if err != nil {
+		return errors.Wrap(err, "while saving alert from capi/community-blocklist")
 	}
-	log.Printf("pull top: added %d entries", len(data.New))
+
+	log.Printf("capi/community-blocklist : added %d entries, deleted %d entries (alert:%d)", inserted, deleted, alertID)
+
 	return nil
 }
 
@@ -379,8 +388,6 @@ func (a *apic) SendMetrics() error {
 			if err != nil {
 				return err
 			}
-			// models.metric structure : len(machines), len(bouncers), a.credentials.Login
-			// _, _, err := a.apiClient.Metrics.Add(//*models.Metrics)
 			for _, machine := range machines {
 				m := &models.MetricsSoftInfo{
 					Version: machine.Version,

+ 146 - 1
pkg/database/alerts.go

@@ -114,6 +114,151 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str
 	return ret, nil
 }
 
+// UpdateCommunityBlocklist stores a single alert together with all of its
+// decisions. The alert and its decisions cannot be bulk-inserted at the same
+// time, so the alert is saved first and the decisions are then inserted in
+// batches of decisionBulkSize. After each batch, older decisions whose origin
+// is CapiMachineID, whose value is in the batch and whose owner is a different
+// alert are deleted, so the entries of this alert replace the previous ones.
+//
+// It returns the ID of the created alert, the number of decisions inserted and
+// the number of older decisions deleted. On error all three are zero; no
+// transaction is opened here, so rows written before the failure stay in place.
+func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) {
+	const decisionBulkSize = 50
+	var deleted, inserted int
+
+	startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt)
+	if err != nil {
+		return 0, 0, 0, errors.Wrapf(ParseTimeFail, "start_at field time '%s': %s", *alertItem.StartAt, err)
+	}
+
+	stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt)
+	if err != nil {
+		return 0, 0, 0, errors.Wrapf(ParseTimeFail, "stop_at field time '%s': %s", *alertItem.StopAt, err)
+	}
+
+	alertB := c.Ent.Alert.
+		Create().
+		SetScenario(*alertItem.Scenario).
+		SetMessage(*alertItem.Message).
+		SetEventsCount(*alertItem.EventsCount).
+		SetStartedAt(startAtTime).
+		SetStoppedAt(stopAtTime).
+		SetSourceScope(*alertItem.Source.Scope).
+		SetSourceValue(*alertItem.Source.Value).
+		SetSourceIp(alertItem.Source.IP).
+		SetSourceRange(alertItem.Source.Range).
+		SetSourceAsNumber(alertItem.Source.AsNumber).
+		SetSourceAsName(alertItem.Source.AsName).
+		SetSourceCountry(alertItem.Source.Cn).
+		SetSourceLatitude(alertItem.Source.Latitude).
+		SetSourceLongitude(alertItem.Source.Longitude).
+		SetCapacity(*alertItem.Capacity).
+		SetLeakSpeed(*alertItem.Leakspeed).
+		SetSimulated(*alertItem.Simulated).
+		SetScenarioVersion(*alertItem.ScenarioVersion).
+		SetScenarioHash(*alertItem.ScenarioHash)
+
+	alertRef, err := alertB.Save(c.CTX)
+	if err != nil {
+		return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err)
+	}
+
+	if len(alertItem.Decisions) == 0 {
+		return alertRef.ID, inserted, deleted, nil
+	}
+
+	decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize)
+	valueList := make([]string, 0, decisionBulkSize)
+	for _, decisionItem := range alertItem.Decisions {
+		var start_ip, start_sfx, end_ip, end_sfx int64
+		var sz int
+
+		duration, err := time.ParseDuration(*decisionItem.Duration)
+		if err != nil {
+			// Dereference the duration so the message shows the value, not a pointer address.
+			return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%v' : %s", *decisionItem.Duration, err)
+		}
+
+		/*if the scope is IP or Range, convert the value to integers */
+		if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" {
+			sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value)
+			if err != nil {
+				// Wrap the address error itself: this is not a duration parsing failure.
+				return 0, 0, 0, errors.Wrapf(err, "invalid addr/range %s", *decisionItem.Value)
+			}
+		}
+		/*bulk insert some new decisions*/
+		// Expiry is relative to the alert's stop time, not to the insertion time.
+		decisionBulk = append(decisionBulk, c.Ent.Decision.Create().
+			SetUntil(stopAtTime.Add(duration)).
+			SetScenario(*decisionItem.Scenario).
+			SetType(*decisionItem.Type).
+			SetStartIP(start_ip).
+			SetStartSuffix(start_sfx).
+			SetEndIP(end_ip).
+			SetEndSuffix(end_sfx).
+			SetIPSize(int64(sz)).
+			SetValue(*decisionItem.Value).
+			SetScope(*decisionItem.Scope).
+			SetOrigin(*decisionItem.Origin).
+			SetSimulated(*alertItem.Simulated).
+			SetOwner(alertRef))
+		/*for bulk delete of duplicate decisions*/
+		valueList = append(valueList, *decisionItem.Value)
+
+		if len(decisionBulk) == decisionBulkSize {
+			insertedDecisions, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
+			if err != nil {
+				return 0, 0, 0, errors.Wrapf(BulkError, "bulk creating decisions : %s", err)
+			}
+			inserted += len(insertedDecisions)
+
+			/*Deleting older decisions from capi*/
+			deletedDecisions, err := c.Ent.Decision.Delete().
+				Where(decision.And(
+					decision.OriginEQ(CapiMachineID),
+					decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
+					decision.ValueIn(valueList...),
+				)).Exec(c.CTX)
+			if err != nil {
+				return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
+			}
+			deleted += deletedDecisions
+
+			// Start a fresh batch; the next one holds at most decisionBulkSize items.
+			decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize)
+			valueList = make([]string, 0, decisionBulkSize)
+
+			// The 90's called, they want their concurrency back.
+			// This is needed for sqlite, which does not support concurrent access while writing.
+			// If we pull a large number of IPs from CAPI, and we have a slow disk, LAPI won't respond until all IPs are inserted (which can take up to a few seconds).
+			time.Sleep(100 * time.Millisecond)
+		}
+	}
+
+	// Flush the last, possibly partial, batch.
+	insertedDecisions, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
+	if err != nil {
+		return 0, 0, 0, errors.Wrapf(BulkError, "creating alert decisions: %s", err)
+	}
+	inserted += len(insertedDecisions)
+	/*Deleting older decisions from capi*/
+	if len(valueList) > 0 {
+		deletedDecisions, err := c.Ent.Decision.Delete().
+			Where(decision.And(
+				decision.OriginEQ(CapiMachineID),
+				decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
+				decision.ValueIn(valueList...),
+			)).Exec(c.CTX)
+		if err != nil {
+			return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
+		}
+		deleted += deletedDecisions
+	}
+
+	return alertRef.ID, inserted, deleted, nil
+}
+
 func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([]string, error) {
 
 	ret := []string{}
@@ -403,7 +548,7 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
 			alerts = alerts.Where(alert.HasDecisionsWith(decision.TypeEQ(value[0])))
 		case "include_capi": //allows to exclude one or more specific origins
 			if value[0] == "false" {
-				alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginNEQ("CAPI")))
+				alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginNEQ(CapiMachineID)))
 			} else if value[0] != "true" {
 				log.Errorf("Invalid bool '%s' for include_capi", value[0])
 			}

+ 17 - 14
pkg/database/database.go

@@ -29,6 +29,17 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 	if config == nil {
 		return &Client{}, fmt.Errorf("DB config is empty")
 	}
+	/*The logger that will be used by db operations*/
+	clog := log.New()
+	if err := types.ConfigureLogger(clog); err != nil {
+		return nil, errors.Wrap(err, "while configuring db logger")
+	}
+	if config.LogLevel != nil {
+		clog.SetLevel(*config.LogLevel)
+	}
+	entLogger := clog.WithField("context", "ent")
+
+	entOpt := ent.Log(entLogger.Debug)
 	switch config.Type {
 	case "sqlite":
 
@@ -46,17 +57,17 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 				return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err)
 			}
 		}
-		client, err = ent.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath))
+		client, err = ent.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), entOpt)
 		if err != nil {
 			return &Client{}, fmt.Errorf("failed opening connection to sqlite: %v", err)
 		}
 	case "mysql":
-		client, err = ent.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName))
+		client, err = ent.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), entOpt)
 		if err != nil {
 			return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err)
 		}
 	case "postgres", "postgresql":
-		client, err = ent.Open("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode))
+		client, err = ent.Open("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), entOpt)
 		if err != nil {
 			return &Client{}, fmt.Errorf("failed opening connection to postgres: %v", err)
 		}
@@ -64,17 +75,9 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 		return &Client{}, fmt.Errorf("unknown database type")
 	}
 
-	/*The logger that will be used by db operations*/
-	clog := log.New()
-	if err := types.ConfigureLogger(clog); err != nil {
-		return nil, errors.Wrap(err, "while configuring db logger")
-	}
-	if config.LogLevel != nil {
-		clog.SetLevel(*config.LogLevel)
-		if *config.LogLevel >= log.DebugLevel {
-			clog.Debugf("Enabling request debug")
-			client = client.Debug()
-		}
+	if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel {
+		clog.Debugf("Enabling request debug")
+		client = client.Debug()
 	}
 	if err = client.Schema.Create(context.Background()); err != nil {
 		return nil, fmt.Errorf("failed creating schema resources: %v", err)

+ 2 - 0
pkg/database/machines.go

@@ -12,6 +12,8 @@ import (
 	"golang.org/x/crypto/bcrypt"
 )
 
+const CapiMachineID = "CAPI" // machine ID set on alerts pulled from CAPI, also matched against decision origin
+
 func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool) (int, error) {
 	hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
 	if err != nil {

+ 12 - 0
pkg/types/utils.go

@@ -187,3 +187,15 @@ func CopyFile(sourceSymLink, destinationFile string) (err error) {
 	err = copyFileContents(sourceFile, destinationFile)
 	return
 }
+
+// StrPtr returns a pointer to a newly allocated string holding the value of s.
+func StrPtr(s string) *string {
+	p := new(string)
+	*p = s
+	return p
+}
+
+// Int32Ptr returns a pointer to a newly allocated int32 holding the value of i.
+func Int32Ptr(i int32) *int32 {
+	p := new(int32)
+	*p = i
+	return p
+}
+
+// BoolPtr returns a pointer to a newly allocated bool holding the value of b.
+func BoolPtr(b bool) *bool {
+	p := new(bool)
+	*p = b
+	return p
+}

+ 1 - 1
scripts/func_tests/tests_post-install_0base.sh

@@ -35,7 +35,7 @@ pidof crowdsec || fail "crowdsec process should be running"
 ${CSCLI} version || fail "cannot run cscli version"
 
 ## alerts
-# alerts list at startup should just return one entry : comunity pull
+# alerts list at startup should just return one entry : community pull
 sleep 5
 ${CSCLI} alerts list -ojson  | ${JQ} '. | length >= 1' || fail "expected at least one entry from cscli alerts list"
 ## capi