apiclient/apiserver: lint/2 (#2741)
This commit is contained in:
parent
75d8ad9798
commit
48f011dc1c
11 changed files with 121 additions and 20 deletions
|
@ -56,7 +56,7 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest)
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
var addedIds models.AddAlertsResponse
|
||||
addedIds := models.AddAlertsResponse{}
|
||||
|
||||
resp, err := s.client.Do(ctx, req, &addedIds)
|
||||
if err != nil {
|
||||
|
|
|
@ -50,6 +50,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
|
|||
log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.Errorf("heartbeat returned false")
|
||||
continue
|
||||
|
|
|
@ -651,6 +651,7 @@ func (a *apic) PullTop(forcePull bool) error {
|
|||
}
|
||||
|
||||
addCounters, deleteCounters := makeAddAndDeleteCounters()
|
||||
|
||||
// process deleted decisions
|
||||
nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
|
||||
if err != nil {
|
||||
|
|
|
@ -38,9 +38,10 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision {
|
|||
}
|
||||
|
||||
func (c *Controller) GetDecision(gctx *gin.Context) {
|
||||
var err error
|
||||
var results []*models.Decision
|
||||
var data []*ent.Decision
|
||||
var (
|
||||
results []*models.Decision
|
||||
data []*ent.Decision
|
||||
)
|
||||
|
||||
bouncerInfo, err := getBouncerFromContext(gctx)
|
||||
if err != nil {
|
||||
|
@ -89,6 +90,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID)
|
||||
if err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
|
@ -351,10 +353,13 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
|
|||
if err != nil {
|
||||
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
|
||||
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
ret["deleted"] = FormatDecisions(data)
|
||||
gctx.JSON(http.StatusOK, ret)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -362,9 +367,11 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
|
|||
var err error
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
|
||||
bouncerInfo, err := getBouncerFromContext(gctx)
|
||||
if err != nil {
|
||||
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -372,6 +379,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
|
|||
//For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
|
||||
//We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
|
||||
gctx.String(http.StatusOK, "")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -115,6 +115,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc {
|
|||
func PrometheusMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
|
||||
LapiRouteHits.With(prometheus.Labels{
|
||||
"route": c.Request.URL.Path,
|
||||
"method": c.Request.Method}).Inc()
|
||||
|
|
|
@ -203,7 +203,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
|
|||
}
|
||||
|
||||
c.Set(bouncerContextKey, bouncer)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
|
|||
func IdentityHandler(c *gin.Context) interface{} {
|
||||
claims := jwt.ExtractClaims(c)
|
||||
machineID := claims[identityKey].(string)
|
||||
|
||||
return &models.WatcherAuthRequest{
|
||||
MachineID: &machineID,
|
||||
}
|
||||
|
@ -93,9 +94,12 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
|||
"ip": c.ClientIP(),
|
||||
"cn": extractedCN,
|
||||
}).Errorf("error generating password: %s", err)
|
||||
|
||||
return nil, fmt.Errorf("error generating password")
|
||||
}
|
||||
|
||||
password := strfmt.Password(pwd)
|
||||
|
||||
ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
|
||||
|
@ -114,27 +118,33 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
|||
}{
|
||||
Scenarios: []string{},
|
||||
}
|
||||
|
||||
err = c.ShouldBindJSON(&loginInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
|
||||
}
|
||||
|
||||
ret.scenariosInput = loginInput.Scenarios
|
||||
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
|
||||
var loginInput models.WatcherAuthRequest
|
||||
var err error
|
||||
var (
|
||||
loginInput models.WatcherAuthRequest
|
||||
err error
|
||||
)
|
||||
|
||||
ret := authInput{}
|
||||
|
||||
if err = c.ShouldBindJSON(&loginInput); err != nil {
|
||||
return nil, fmt.Errorf("missing: %w", err)
|
||||
}
|
||||
|
||||
if err = loginInput.Validate(strfmt.Default); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret.machineID = *loginInput.MachineID
|
||||
password := *loginInput.Password
|
||||
ret.scenariosInput = loginInput.Scenarios
|
||||
|
@ -168,8 +178,10 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
|
|||
}
|
||||
|
||||
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
||||
var err error
|
||||
var auth *authInput
|
||||
var (
|
||||
err error
|
||||
auth *authInput
|
||||
)
|
||||
|
||||
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
|
||||
auth, err = j.authTLS(c)
|
||||
|
@ -193,6 +205,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
scenarios += "," + scenario
|
||||
}
|
||||
}
|
||||
|
||||
err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
|
||||
|
@ -210,6 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
|
||||
if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" {
|
||||
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress)
|
||||
|
||||
err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
|
||||
|
@ -228,10 +242,10 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
|||
log.Errorf("bad user agent from : %s", c.ClientIP())
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
return &models.WatcherAuthRequest{
|
||||
MachineID: &auth.machineID,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func Authorizator(data interface{}, c *gin.Context) bool {
|
||||
|
|
|
@ -18,5 +18,6 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) {
|
|||
}
|
||||
|
||||
ret.APIKey = NewAPIKey(dbClient)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
|
|
@ -36,32 +36,40 @@ func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509
|
|||
ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
|
||||
if err != nil {
|
||||
ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ocspURL, err := url.Parse(server)
|
||||
if err != nil {
|
||||
ta.logger.Error("TLSAuth: cannot parse OCSP URL")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpRequest.Header.Add("Content-Type", "application/ocsp-request")
|
||||
httpRequest.Header.Add("Accept", "application/ocsp-response")
|
||||
httpRequest.Header.Add("host", ocspURL.Host)
|
||||
|
||||
httpClient := &http.Client{}
|
||||
|
||||
httpResponse, err := httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
|
||||
return nil, err
|
||||
}
|
||||
defer httpResponse.Body.Close()
|
||||
|
||||
output, err := io.ReadAll(httpResponse.Body)
|
||||
if err != nil {
|
||||
ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
|
||||
|
||||
return ocspResponse, err
|
||||
}
|
||||
|
||||
|
@ -72,10 +80,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
|
|||
ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
|
||||
return true
|
||||
}
|
||||
|
||||
if cert.NotBefore.UTC().After(now) {
|
||||
ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -84,12 +94,14 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
|
|||
ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, server := range cert.OCSPServer {
|
||||
ocspResponse, err := ta.ocspQuery(server, cert, issuer)
|
||||
if err != nil {
|
||||
ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch ocspResponse.Status {
|
||||
case ocsp.Good:
|
||||
return false, nil
|
||||
|
@ -100,7 +112,9 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
|
|||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
@ -109,24 +123,29 @@ func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
|
|||
ta.logger.Warn("no crl_path, skipping CRL check")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
crlContent, err := os.ReadFile(ta.CrlPath)
|
||||
if err != nil {
|
||||
ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
crl, err := x509.ParseCRL(crlContent)
|
||||
if err != nil {
|
||||
ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if crl.HasExpired(time.Now().UTC()) {
|
||||
ta.logger.Warn("CRL has expired, will still validate the cert against it.")
|
||||
}
|
||||
|
||||
for _, revoked := range crl.TBSCertList.RevokedCertificates {
|
||||
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
|
||||
return true, fmt.Errorf("client certificate is revoked by CRL")
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
@ -143,6 +162,7 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
|
|||
} else {
|
||||
ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
|
||||
}
|
||||
|
||||
revoked, err := ta.isOCSPRevoked(cert, issuer)
|
||||
if err != nil {
|
||||
ta.revokationCache[sn] = cacheEntry{
|
||||
|
@ -150,22 +170,27 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
|
|||
err: err,
|
||||
timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
return true, err
|
||||
}
|
||||
|
||||
if revoked {
|
||||
ta.revokationCache[sn] = cacheEntry{
|
||||
revoked: revoked,
|
||||
err: err,
|
||||
timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
revoked, err = ta.isCRLRevoked(cert)
|
||||
ta.revokationCache[sn] = cacheEntry{
|
||||
revoked: revoked,
|
||||
err: err,
|
||||
timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
return revoked, err
|
||||
}
|
||||
|
||||
|
@ -173,6 +198,7 @@ func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (
|
|||
if ta.isExpired(cert) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
revoked, err := ta.isRevoked(cert, issuer)
|
||||
if err != nil {
|
||||
//Fail securely, if we can't check the revocation status, let's consider the cert invalid
|
||||
|
@ -189,24 +215,30 @@ func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
|
|||
if ou == "" {
|
||||
return fmt.Errorf("empty ou isn't allowed")
|
||||
}
|
||||
|
||||
//drop & warn on duplicate ou
|
||||
ok := true
|
||||
|
||||
for _, validOu := range ta.AllowedOUs {
|
||||
if validOu == ou {
|
||||
ta.logger.Warningf("dropping duplicate ou %s", ou)
|
||||
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
|
||||
if ok {
|
||||
ta.AllowedOUs = append(ta.AllowedOUs, ou)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
|
||||
//Checks cert validity, Returns true + CN if client cert matches requested OU
|
||||
var clientCert *x509.Certificate
|
||||
|
||||
if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
|
||||
//do not error if it's not TLS or there are no peer certs
|
||||
return false, "", nil
|
||||
|
@ -215,6 +247,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
|
|||
if len(c.Request.TLS.VerifiedChains) > 0 {
|
||||
validOU := false
|
||||
clientCert = c.Request.TLS.VerifiedChains[0][0]
|
||||
|
||||
for _, ou := range clientCert.Subject.OrganizationalUnit {
|
||||
for _, allowedOu := range ta.AllowedOUs {
|
||||
if allowedOu == ou {
|
||||
|
@ -223,21 +256,27 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !validOU {
|
||||
return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
|
||||
clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
|
||||
}
|
||||
|
||||
revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
|
||||
if err != nil {
|
||||
ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
|
||||
return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err)
|
||||
}
|
||||
|
||||
if revoked {
|
||||
return false, "", fmt.Errorf("client certificate is revoked")
|
||||
}
|
||||
|
||||
ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
|
||||
|
||||
return true, clientCert.Subject.CommonName, nil
|
||||
}
|
||||
|
||||
return false, "", fmt.Errorf("no verified cert in request")
|
||||
}
|
||||
|
||||
|
@ -248,9 +287,11 @@ func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Durati
|
|||
CrlPath: crlPath,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
err := ta.SetAllowedOu(allowedOus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ta, nil
|
||||
}
|
||||
|
|
|
@ -205,12 +205,15 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
|
|||
reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order
|
||||
eventsCount := len(events)
|
||||
p.Logger.Infof("received %d events", eventsCount)
|
||||
|
||||
for i, event := range reversedEvents {
|
||||
if err := p.handleEvent(event, sync); err != nil {
|
||||
p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err)
|
||||
}
|
||||
|
||||
p.Logger.Debugf("handled event %d/%d", i, eventsCount)
|
||||
}
|
||||
|
||||
p.Logger.Debugf("finished handling events")
|
||||
//Don't update the timestamp in DB, as a "real" LAPI might be running
|
||||
//Worst case, crowdsec will receive a few duplicated events and will discard them
|
||||
|
@ -223,16 +226,19 @@ func (p *Papi) Pull() error {
|
|||
p.Logger.Infof("Starting Polling API Pull")
|
||||
|
||||
lastTimestamp := time.Time{}
|
||||
|
||||
lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
|
||||
if err != nil {
|
||||
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
|
||||
}
|
||||
|
||||
//value doesn't exist, it's first time we're pulling
|
||||
if lastTimestampStr == nil {
|
||||
binTime, err := lastTimestamp.MarshalText()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal last timestamp: %w", err)
|
||||
}
|
||||
|
||||
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
|
||||
p.Logger.Errorf("error setting papi pull last key: %s", err)
|
||||
} else {
|
||||
|
@ -245,10 +251,12 @@ func (p *Papi) Pull() error {
|
|||
}
|
||||
|
||||
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
|
||||
|
||||
for event := range p.Client.Start(lastTimestamp) {
|
||||
logger := p.Logger.WithField("request-id", event.RequestId)
|
||||
//update last timestamp in database
|
||||
newTime := time.Now().UTC()
|
||||
|
||||
binTime, err := newTime.MarshalText()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal last timestamp: %w", err)
|
||||
|
@ -262,11 +270,11 @@ func (p *Papi) Pull() error {
|
|||
|
||||
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
|
||||
return fmt.Errorf("failed to update last timestamp: %w", err)
|
||||
} else {
|
||||
logger.Debugf("set last timestamp to %s", newTime)
|
||||
}
|
||||
|
||||
logger.Debugf("set last timestamp to %s", newTime)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -274,6 +282,7 @@ func (p *Papi) SyncDecisions() error {
|
|||
defer trace.CatchPanic("lapi/syncDecisionsToCAPI")
|
||||
|
||||
var cache models.DecisionsDeleteRequest
|
||||
|
||||
ticker := time.NewTicker(p.SyncInterval)
|
||||
p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval)
|
||||
|
||||
|
@ -281,10 +290,13 @@ func (p *Papi) SyncDecisions() error {
|
|||
select {
|
||||
case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others?
|
||||
p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache))
|
||||
|
||||
if len(cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
go p.SendDeletedDecisions(&cache)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if len(cache) > 0 {
|
||||
|
@ -293,15 +305,19 @@ func (p *Papi) SyncDecisions() error {
|
|||
cache = make([]models.DecisionsDeleteRequestItem, 0)
|
||||
p.mu.Unlock()
|
||||
p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy))
|
||||
|
||||
go p.SendDeletedDecisions(&cacheCopy)
|
||||
}
|
||||
case deletedDecisions := <-p.Channels.DeleteDecisionChannel:
|
||||
if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) {
|
||||
var tmpDecisions []models.DecisionsDeleteRequestItem
|
||||
|
||||
p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions))
|
||||
|
||||
for _, decision := range deletedDecisions {
|
||||
tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID))
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
cache = append(cache, tmpDecisions...)
|
||||
p.mu.Unlock()
|
||||
|
@ -311,33 +327,42 @@ func (p *Papi) SyncDecisions() error {
|
|||
}
|
||||
|
||||
func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) {
|
||||
|
||||
var cache []models.DecisionsDeleteRequestItem = *cacheOrig
|
||||
var send models.DecisionsDeleteRequest
|
||||
var (
|
||||
cache []models.DecisionsDeleteRequestItem = *cacheOrig
|
||||
send models.DecisionsDeleteRequest
|
||||
)
|
||||
|
||||
bulkSize := 50
|
||||
pageStart := 0
|
||||
pageEnd := bulkSize
|
||||
|
||||
for {
|
||||
if pageEnd >= len(cache) {
|
||||
send = cache[pageStart:]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
defer cancel()
|
||||
|
||||
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
|
||||
if err != nil {
|
||||
p.Logger.Errorf("sending deleted decisions to central API: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
send = cache[pageStart:pageEnd]
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
defer cancel()
|
||||
|
||||
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
|
||||
if err != nil {
|
||||
//we log it here as well, because the return value of func might be discarded
|
||||
p.Logger.Errorf("sending deleted decisions to central API: %s", err)
|
||||
}
|
||||
|
||||
pageStart += bulkSize
|
||||
pageEnd += bulkSize
|
||||
}
|
||||
|
|
|
@ -40,17 +40,18 @@ type forcePull struct {
|
|||
func DecisionCmd(message *Message, p *Papi, sync bool) error {
|
||||
switch message.Header.OperationCmd {
|
||||
case "delete":
|
||||
|
||||
data, err := json.Marshal(message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
UUIDs := make([]string, 0)
|
||||
deleteDecisionMsg := deleteDecisions{
|
||||
Decisions: make([]string, 0),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil {
|
||||
return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err)
|
||||
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
|
||||
}
|
||||
|
||||
UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...)
|
||||
|
@ -59,10 +60,13 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
|
|||
filter := make(map[string][]string)
|
||||
filter["uuid"] = UUIDs
|
||||
_, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err)
|
||||
return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err)
|
||||
}
|
||||
|
||||
decisions := make([]*models.Decision, 0)
|
||||
|
||||
for _, deletedDecision := range deletedDecisions {
|
||||
log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type)
|
||||
dec := &models.Decision{
|
||||
|
@ -92,6 +96,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
alert := &models.Alert{}
|
||||
|
||||
if err := json.Unmarshal(data, alert); err != nil {
|
||||
|
@ -105,10 +110,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
|
|||
log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID)
|
||||
alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
if alert.StopAt == nil || *alert.StopAt == "" {
|
||||
log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID)
|
||||
alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
alert.EventsCount = ptr.Of(int32(0))
|
||||
alert.Capacity = ptr.Of(int32(0))
|
||||
alert.Leakspeed = ptr.Of("")
|
||||
|
@ -128,12 +135,14 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
|
|||
alert.Source.Scope = ptr.Of(types.ConsoleOrigin)
|
||||
alert.Source.Value = &message.Header.Source.User
|
||||
}
|
||||
|
||||
alert.Scenario = &message.Header.Message
|
||||
|
||||
for _, decision := range alert.Decisions {
|
||||
if *decision.Scenario == "" {
|
||||
decision.Scenario = &message.Header.Message
|
||||
}
|
||||
|
||||
log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID)
|
||||
}
|
||||
|
||||
|
@ -157,6 +166,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
|
|||
log.Infof("Ignoring management command from PAPI in sync mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch message.Header.OperationCmd {
|
||||
case "reauth":
|
||||
log.Infof("Received reauth command from PAPI, resetting token")
|
||||
|
@ -187,12 +197,12 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
|
|||
Duration: &forcePullMsg.Blocklist.Duration,
|
||||
}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to force pull operation: %s", err)
|
||||
return fmt.Errorf("failed to force pull operation: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue