From 48f011dc1c24d6c31ee47741ab6c4b20e0cf0f16 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 15 Jan 2024 12:38:31 +0100 Subject: [PATCH] apiclient/apiserver: lint/2 (#2741) --- pkg/apiclient/alerts_service.go | 2 +- pkg/apiclient/heartbeat.go | 1 + pkg/apiserver/apic.go | 1 + pkg/apiserver/controllers/v1/decisions.go | 14 +++++-- pkg/apiserver/controllers/v1/metrics.go | 1 + pkg/apiserver/middlewares/v1/api_key.go | 1 - pkg/apiserver/middlewares/v1/jwt.go | 24 +++++++++--- pkg/apiserver/middlewares/v1/middlewares.go | 1 + pkg/apiserver/middlewares/v1/tls_auth.go | 41 +++++++++++++++++++++ pkg/apiserver/papi.go | 35 +++++++++++++++--- pkg/apiserver/papi_cmd.go | 20 +++++++--- 11 files changed, 121 insertions(+), 20 deletions(-) diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index eb41452ea..ad75dd393 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -56,7 +56,7 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) return nil, nil, err } - var addedIds models.AddAlertsResponse + addedIds := models.AddAlertsResponse{} resp, err := s.client.Do(ctx, req, &addedIds) if err != nil { diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index df3afc52f..c6b3d0832 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -50,6 +50,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode) continue } + if !ok { log.Errorf("heartbeat returned false") continue diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 961a9b5ac..d0b205c25 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -651,6 +651,7 @@ func (a *apic) PullTop(forcePull bool) error { } addCounters, deleteCounters := makeAddAndDeleteCounters() + // process deleted decisions nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) if err != nil { diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 9acfc1f2e..f3c6a7bba 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -38,9 +38,10 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { } func (c *Controller) GetDecision(gctx *gin.Context) { - var err error - var results []*models.Decision - var data []*ent.Decision + var ( + results []*models.Decision + data []*ent.Decision + ) bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { @@ -89,6 +90,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } + nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -351,10 +353,13 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } @@ -362,9 +367,11 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { var err error streamStartTime := time.Now().UTC() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } @@ -372,6 +379,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) gctx.String(http.StatusOK, "") + return } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index b1d95dd67..13ccf9ac9 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -115,6 +115,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { func PrometheusMiddleware() gin.HandlerFunc { return func(c *gin.Context) { startTime := time.Now() + LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 682f6b638..2f5f808ca 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -203,7 +203,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } c.Set(bouncerContextKey, bouncer) - c.Next() } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index ef863a7a2..ed4ad107b 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -43,6 +43,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims { func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) machineID := claims[identityKey].(string) + return &models.WatcherAuthRequest{ MachineID: &machineID, } @@ -93,9 +94,12 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { "ip": c.ClientIP(), "cn": extractedCN, }).Errorf("error generating password: %s", err) + return nil, fmt.Errorf("error generating password") } + password := strfmt.Password(pwd) + ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) @@ -114,27 +118,33 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { }{ Scenarios: []string{}, } + err = c.ShouldBindJSON(&loginInput) if err != nil { return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) } + ret.scenariosInput = loginInput.Scenarios return &ret, nil } func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { - var loginInput models.WatcherAuthRequest - var err error + var ( + loginInput models.WatcherAuthRequest + err error + ) ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { return nil, fmt.Errorf("missing: %w", err) } + if err = loginInput.Validate(strfmt.Default); err != nil { return nil, err } + ret.machineID = *loginInput.MachineID password := *loginInput.Password ret.scenariosInput = loginInput.Scenarios @@ -168,8 +178,10 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { } func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { - var err error - var auth *authInput + var ( + err error + auth *authInput + ) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) @@ -193,6 +205,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { scenarios += "," + scenario } } + err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) @@ -210,6 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) + err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) @@ -228,10 +242,10 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { log.Errorf("bad user agent from : %s", c.ClientIP()) return nil, jwt.ErrFailedAuthentication } + return &models.WatcherAuthRequest{ MachineID: &auth.machineID, }, nil - } func Authorizator(data interface{}, c *gin.Context) bool { diff --git a/pkg/apiserver/middlewares/v1/middlewares.go b/pkg/apiserver/middlewares/v1/middlewares.go index ef2d93b92..a5409ea5c 100644 --- a/pkg/apiserver/middlewares/v1/middlewares.go +++ b/pkg/apiserver/middlewares/v1/middlewares.go @@ -18,5 +18,6 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) { } ret.APIKey = NewAPIKey(dbClient) + return ret, nil } diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go index 87ca896a8..904f6cd44 100644 --- a/pkg/apiserver/middlewares/v1/tls_auth.go +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -36,32 +36,40 @@ func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509 ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) return nil, err } + httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) if err != nil { ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") return nil, err } + ocspURL, err := url.Parse(server) if err != nil { ta.logger.Error("TLSAuth: cannot parse OCSP URL") return nil, err } + httpRequest.Header.Add("Content-Type", "application/ocsp-request") httpRequest.Header.Add("Accept", "application/ocsp-response") httpRequest.Header.Add("host", ocspURL.Host) + httpClient := &http.Client{} + httpResponse, err := httpClient.Do(httpRequest) if err != nil { ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") return nil, err } defer httpResponse.Body.Close() + output, err := io.ReadAll(httpResponse.Body) if err != nil { ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP") return nil, err } + ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) + return ocspResponse, err } @@ -72,10 +80,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC()) return true } + if cert.NotBefore.UTC().After(now) { ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC()) return true } + return false } @@ -84,12 +94,14 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") return false, nil } + for _, server := range cert.OCSPServer { ocspResponse, err := ta.ocspQuery(server, cert, issuer) if err != nil { ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) continue } + switch ocspResponse.Status { case ocsp.Good: return false, nil @@ -100,7 +112,9 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat continue } } + log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") + return true, nil } @@ -109,24 +123,29 @@ func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { ta.logger.Warn("no crl_path, skipping CRL check") return false, nil } + crlContent, err := os.ReadFile(ta.CrlPath) if err != nil { ta.logger.Warnf("could not read CRL file, skipping check: %s", err) return false, nil } + crl, err := x509.ParseCRL(crlContent) if err != nil { ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) return false, nil } + if crl.HasExpired(time.Now().UTC()) { ta.logger.Warn("CRL has expired, will still validate the cert against it.") } + for _, revoked := range crl.TBSCertList.RevokedCertificates { if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { return true, fmt.Errorf("client certificate is revoked by CRL") } } + return false, nil } @@ -143,6 +162,7 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) ( } else { ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) } + revoked, err := ta.isOCSPRevoked(cert, issuer) if err != nil { ta.revokationCache[sn] = cacheEntry{ @@ -150,22 +170,27 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) ( err: err, timestamp: time.Now().UTC(), } + return true, err } + if revoked { ta.revokationCache[sn] = cacheEntry{ revoked: revoked, err: err, timestamp: time.Now().UTC(), } + return true, nil } + revoked, err = ta.isCRLRevoked(cert) ta.revokationCache[sn] = cacheEntry{ revoked: revoked, err: err, timestamp: time.Now().UTC(), } + return revoked, err } @@ -173,6 +198,7 @@ func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) ( if ta.isExpired(cert) { return true, nil } + revoked, err := ta.isRevoked(cert, issuer) if err != nil { //Fail securely, if we can't check the revocation status, let's consider the cert invalid @@ -189,24 +215,30 @@ func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error { if ou == "" { return fmt.Errorf("empty ou isn't allowed") } + //drop & warn on duplicate ou ok := true + for _, validOu := range ta.AllowedOUs { if validOu == ou { ta.logger.Warningf("dropping duplicate ou %s", ou) + ok = false } } + if ok { ta.AllowedOUs = append(ta.AllowedOUs, ou) } } + return nil } func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { //Checks cert validity, Returns true + CN if client cert matches requested OU var clientCert *x509.Certificate + if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { //do not error if it's not TLS or there are no peer certs return false, "", nil @@ -215,6 +247,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { if len(c.Request.TLS.VerifiedChains) > 0 { validOU := false clientCert = c.Request.TLS.VerifiedChains[0][0] + for _, ou := range clientCert.Subject.OrganizationalUnit { for _, allowedOu := range ta.AllowedOUs { if allowedOu == ou { @@ -223,21 +256,27 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { } } } + if !validOU { return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) } + revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) if err != nil { ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) } + if revoked { return false, "", fmt.Errorf("client certificate is revoked") } + ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) + return true, clientCert.Subject.CommonName, nil } + return false, "", fmt.Errorf("no verified cert in request") } @@ -248,9 +287,11 @@ func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Durati CrlPath: crlPath, logger: logger, } + err := ta.SetAllowedOu(allowedOus) if err != nil { return nil, err } + return ta, nil } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 2cf032d26..a3996850a 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -205,12 +205,15 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order eventsCount := len(events) p.Logger.Infof("received %d events", eventsCount) + for i, event := range reversedEvents { if err := p.handleEvent(event, sync); err != nil { p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err) } + p.Logger.Debugf("handled event %d/%d", i, eventsCount) } + p.Logger.Debugf("finished handling events") //Don't update the timestamp in DB, as a "real" LAPI might be running //Worst case, crowdsec will receive a few duplicated events and will discard them @@ -223,16 +226,19 @@ func (p *Papi) Pull() error { p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} + lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } + //value doesn't exist, it's first time we're pulling if lastTimestampStr == nil { binTime, err := lastTimestamp.MarshalText() if err != nil { return fmt.Errorf("failed to marshal last timestamp: %w", err) } + if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { @@ -245,10 +251,12 @@ func (p *Papi) Pull() error { } p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) + for event := range p.Client.Start(lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) //update last timestamp in database newTime := time.Now().UTC() + binTime, err := newTime.MarshalText() if err != nil { return fmt.Errorf("failed to marshal last timestamp: %w", err) @@ -262,11 +270,11 @@ func (p *Papi) Pull() error { if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) - } else { - logger.Debugf("set last timestamp to %s", newTime) } + logger.Debugf("set last timestamp to %s", newTime) } + return nil } @@ -274,6 +282,7 @@ func (p *Papi) SyncDecisions() error { defer trace.CatchPanic("lapi/syncDecisionsToCAPI") var cache models.DecisionsDeleteRequest + ticker := time.NewTicker(p.SyncInterval) p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) @@ -281,10 +290,13 @@ func (p *Papi) SyncDecisions() error { select { case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) + if len(cache) == 0 { return nil } + go p.SendDeletedDecisions(&cache) + return nil case <-ticker.C: if len(cache) > 0 { @@ -293,15 +305,19 @@ func (p *Papi) SyncDecisions() error { cache = make([]models.DecisionsDeleteRequestItem, 0) p.mu.Unlock() p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy)) + go p.SendDeletedDecisions(&cacheCopy) } case deletedDecisions := <-p.Channels.DeleteDecisionChannel: if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) { var tmpDecisions []models.DecisionsDeleteRequestItem + p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions)) + for _, decision := range deletedDecisions { tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID)) } + p.mu.Lock() cache = append(cache, tmpDecisions...) p.mu.Unlock() @@ -311,33 +327,42 @@ func (p *Papi) SyncDecisions() error { } func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { - - var cache []models.DecisionsDeleteRequestItem = *cacheOrig - var send models.DecisionsDeleteRequest + var ( + cache []models.DecisionsDeleteRequestItem = *cacheOrig + send models.DecisionsDeleteRequest + ) bulkSize := 50 pageStart := 0 pageEnd := bulkSize + for { if pageEnd >= len(cache) { send = cache[pageStart:] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { p.Logger.Errorf("sending deleted decisions to central API: %s", err) return } + break } + send = cache[pageStart:pageEnd] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) if err != nil { //we log it here as well, because the return value of func might be discarded p.Logger.Errorf("sending deleted decisions to central API: %s", err) } + pageStart += bulkSize pageEnd += bulkSize } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 6ab8f3734..ba0203488 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -40,17 +40,18 @@ type forcePull struct { func DecisionCmd(message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "delete": - data, err := json.Marshal(message.Data) if err != nil { return err } + UUIDs := make([]string, 0) deleteDecisionMsg := deleteDecisions{ Decisions: make([]string, 0), } + if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil { - return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...) @@ -59,10 +60,13 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs _, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter) + if err != nil { - return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err) + return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err) } + decisions := make([]*models.Decision, 0) + for _, deletedDecision := range deletedDecisions { log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type) dec := &models.Decision{ @@ -92,6 +96,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { if err != nil { return err } + alert := &models.Alert{} if err := json.Unmarshal(data, alert); err != nil { @@ -105,10 +110,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID) alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + if alert.StopAt == nil || *alert.StopAt == "" { log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID) alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) } + alert.EventsCount = ptr.Of(int32(0)) alert.Capacity = ptr.Of(int32(0)) alert.Leakspeed = ptr.Of("") @@ -128,12 +135,14 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { alert.Source.Scope = ptr.Of(types.ConsoleOrigin) alert.Source.Value = &message.Header.Source.User } + alert.Scenario = &message.Header.Message for _, decision := range alert.Decisions { if *decision.Scenario == "" { decision.Scenario = &message.Header.Message } + log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID) } @@ -157,6 +166,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { log.Infof("Ignoring management command from PAPI in sync mode") return nil } + switch message.Header.OperationCmd { case "reauth": log.Infof("Received reauth command from PAPI, resetting token") @@ -187,12 +197,12 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { Duration: &forcePullMsg.Blocklist.Duration, }, true) if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + return fmt.Errorf("failed to force pull operation: %w", err) } } - default: return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) } + return nil }