Browse Source

Use explicit transaction when inserting community blocklist (#1835)

blotus 2 years ago
parent
commit
b7c4bfd4e3
3 changed files with 46 additions and 11 deletions
  1. 6 0
      pkg/apiserver/apic.go
  2. 37 10
      pkg/database/alerts.go
  3. 3 1
      pkg/database/database.go

+ 6 - 0
pkg/apiserver/apic.go

@@ -419,6 +419,9 @@ func (a *apic) PullTop() error {
 	a.startup = false
 	a.startup = false
 	/*to count additions/deletions across lists*/
 	/*to count additions/deletions across lists*/
 
 
+	log.Debugf("Received %d new decisions", len(data.New))
+	log.Debugf("Received %d deleted decisions", len(data.Deleted))
+
 	add_counters, delete_counters := makeAddAndDeleteCounters()
 	add_counters, delete_counters := makeAddAndDeleteCounters()
 	// process deleted decisions
 	// process deleted decisions
 	if nbDeleted, err := a.HandleDeletedDecisions(data.Deleted, delete_counters); err != nil {
 	if nbDeleted, err := a.HandleDeletedDecisions(data.Deleted, delete_counters); err != nil {
@@ -441,6 +444,9 @@ func (a *apic) PullTop() error {
 	for idx, alert := range alertsFromCapi {
 	for idx, alert := range alertsFromCapi {
 		alertsFromCapi[idx] = setAlertScenario(add_counters, delete_counters, alert)
 		alertsFromCapi[idx] = setAlertScenario(add_counters, delete_counters, alert)
 		log.Debugf("%s has %d decisions", *alertsFromCapi[idx].Source.Scope, len(alertsFromCapi[idx].Decisions))
 		log.Debugf("%s has %d decisions", *alertsFromCapi[idx].Source.Scope, len(alertsFromCapi[idx].Decisions))
+		if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
+			log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
+		}
 		alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alertsFromCapi[idx])
 		alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alertsFromCapi[idx])
 		if err != nil {
 		if err != nil {
 			return errors.Wrapf(err, "while saving alert from %s", *alertsFromCapi[idx].Source.Scope)
 			return errors.Wrapf(err, "while saving alert from %s", *alertsFromCapi[idx].Source.Scope)

+ 37 - 10
pkg/database/alerts.go

@@ -178,6 +178,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
 	}
 	}
 
 
 	if len(alertItem.Decisions) > 0 {
 	if len(alertItem.Decisions) > 0 {
+		txClient, err := c.Ent.Tx(c.CTX)
+		if err != nil {
+			return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err)
+		}
 		decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize)
 		decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize)
 		valueList := make([]string, 0, decisionBulkSize)
 		valueList := make([]string, 0, decisionBulkSize)
 		DecOrigin := CapiMachineID
 		DecOrigin := CapiMachineID
@@ -195,6 +199,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
 			}
 			}
 			duration, err := time.ParseDuration(*decisionItem.Duration)
 			duration, err := time.ParseDuration(*decisionItem.Duration)
 			if err != nil {
 			if err != nil {
+				rollbackErr := txClient.Rollback()
+				if rollbackErr != nil {
+					log.Errorf("rollback error: %s", rollbackErr)
+				}
 				return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err)
 				return 0, 0, 0, errors.Wrapf(ParseDurationFail, "decision duration '%+v' : %s", *decisionItem.Duration, err)
 			}
 			}
 			if decisionItem.Scope == nil {
 			if decisionItem.Scope == nil {
@@ -205,6 +213,10 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
 			if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" {
 			if strings.ToLower(*decisionItem.Scope) == "ip" || strings.ToLower(*decisionItem.Scope) == "range" {
 				sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value)
 				sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(*decisionItem.Value)
 				if err != nil {
 				if err != nil {
+					rollbackErr := txClient.Rollback()
+					if rollbackErr != nil {
+						log.Errorf("rollback error: %s", rollbackErr)
+					}
 					return 0, 0, 0, errors.Wrapf(ParseDurationFail, "invalid addr/range %s : %s", *decisionItem.Value, err)
 					return 0, 0, 0, errors.Wrapf(ParseDurationFail, "invalid addr/range %s : %s", *decisionItem.Value, err)
 				}
 				}
 			}
 			}
@@ -232,20 +244,29 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
 			valueList = append(valueList, *decisionItem.Value)
 			valueList = append(valueList, *decisionItem.Value)
 
 
 			if len(decisionBulk) == decisionBulkSize {
 			if len(decisionBulk) == decisionBulkSize {
-				insertedDecisions, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
+
+				insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
 				if err != nil {
 				if err != nil {
+					rollbackErr := txClient.Rollback()
+					if rollbackErr != nil {
+						log.Errorf("rollback error: %s", rollbackErr)
+					}
 					return 0, 0, 0, errors.Wrapf(BulkError, "bulk creating decisions : %s", err)
 					return 0, 0, 0, errors.Wrapf(BulkError, "bulk creating decisions : %s", err)
 				}
 				}
 				inserted += len(insertedDecisions)
 				inserted += len(insertedDecisions)
 
 
 				/*Deleting older decisions from capi*/
 				/*Deleting older decisions from capi*/
-				deletedDecisions, err := c.Ent.Decision.Delete().
+				deletedDecisions, err := txClient.Decision.Delete().
 					Where(decision.And(
 					Where(decision.And(
 						decision.OriginEQ(DecOrigin),
 						decision.OriginEQ(DecOrigin),
 						decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
 						decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
 						decision.ValueIn(valueList...),
 						decision.ValueIn(valueList...),
 					)).Exec(c.CTX)
 					)).Exec(c.CTX)
 				if err != nil {
 				if err != nil {
+					rollbackErr := txClient.Rollback()
+					if rollbackErr != nil {
+						log.Errorf("rollback error: %s", rollbackErr)
+					}
 					return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
 					return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
 				}
 				}
 				deleted += deletedDecisions
 				deleted += deletedDecisions
@@ -257,34 +278,40 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in
 					decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize)
 					decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize)
 					valueList = make([]string, 0, decisionBulkSize)
 					valueList = make([]string, 0, decisionBulkSize)
 				}
 				}
-
-				// The 90's called, they want their concurrency back.
-				// This is needed for sqlite, which does not support concurrent access while writing.
-				// If we pull a large number of IPs from CAPI, and we have a slow disk, LAPI won't respond until all IPs are inserted (which can take up to a few seconds).
-				time.Sleep(100 * time.Millisecond)
 			}
 			}
 
 
 		}
 		}
 		log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin)
 		log.Debugf("deleted %d decisions for %s vs %s", deleted, DecOrigin, *alertItem.Decisions[0].Origin)
-		insertedDecisions, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
+		insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX)
 		if err != nil {
 		if err != nil {
 			return 0, 0, 0, errors.Wrapf(BulkError, "creating alert decisions: %s", err)
 			return 0, 0, 0, errors.Wrapf(BulkError, "creating alert decisions: %s", err)
 		}
 		}
 		inserted += len(insertedDecisions)
 		inserted += len(insertedDecisions)
 		/*Deleting older decisions from capi*/
 		/*Deleting older decisions from capi*/
 		if len(valueList) > 0 {
 		if len(valueList) > 0 {
-			deletedDecisions, err := c.Ent.Decision.Delete().
+			deletedDecisions, err := txClient.Decision.Delete().
 				Where(decision.And(
 				Where(decision.And(
 					decision.OriginEQ(DecOrigin),
 					decision.OriginEQ(DecOrigin),
 					decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
 					decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))),
 					decision.ValueIn(valueList...),
 					decision.ValueIn(valueList...),
 				)).Exec(c.CTX)
 				)).Exec(c.CTX)
 			if err != nil {
 			if err != nil {
+				rollbackErr := txClient.Rollback()
+				if rollbackErr != nil {
+					log.Errorf("rollback error: %s", rollbackErr)
+				}
 				return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
 				return 0, 0, 0, errors.Wrap(err, "while deleting older community blocklist decisions")
 			}
 			}
 			deleted += deletedDecisions
 			deleted += deletedDecisions
 		}
 		}
-
+		err = txClient.Commit()
+		if err != nil {
+			rollbackErr := txClient.Rollback()
+			if rollbackErr != nil {
+				log.Errorf("rollback error: %s", rollbackErr)
+			}
+			return 0, 0, 0, errors.Wrapf(BulkError, "error committing transaction : %s", err)
+		}
 	}
 	}
 
 
 	return alertRef.ID, inserted, deleted, nil
 	return alertRef.ID, inserted, deleted, nil

+ 3 - 1
pkg/database/database.go

@@ -26,6 +26,8 @@ type Client struct {
 	CTX      context.Context
 	CTX      context.Context
 	Log      *log.Logger
 	Log      *log.Logger
 	CanFlush bool
 	CanFlush bool
+	Type     string
+	WalMode  *bool
 }
 }
 
 
 func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) {
 func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) {
@@ -118,7 +120,7 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 	if err = client.Schema.Create(context.Background()); err != nil {
 	if err = client.Schema.Create(context.Background()); err != nil {
 		return nil, fmt.Errorf("failed creating schema resources: %v", err)
 		return nil, fmt.Errorf("failed creating schema resources: %v", err)
 	}
 	}
-	return &Client{Ent: client, CTX: context.Background(), Log: clog, CanFlush: true}, nil
+	return &Client{Ent: client, CTX: context.Background(), Log: clog, CanFlush: true, Type: config.Type, WalMode: config.UseWal}, nil
 }
 }
 
 
 func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) {
 func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) {