apiclient/apiserver: lint/2 (#2741)

This commit is contained in:
mmetc 2024-01-15 12:38:31 +01:00 committed by GitHub
parent 75d8ad9798
commit 48f011dc1c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 121 additions and 20 deletions

View file

@ -56,7 +56,7 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest)
return nil, nil, err return nil, nil, err
} }
var addedIds models.AddAlertsResponse addedIds := models.AddAlertsResponse{}
resp, err := s.client.Do(ctx, req, &addedIds) resp, err := s.client.Do(ctx, req, &addedIds)
if err != nil { if err != nil {

View file

@ -50,6 +50,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode) log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode)
continue continue
} }
if !ok { if !ok {
log.Errorf("heartbeat returned false") log.Errorf("heartbeat returned false")
continue continue

View file

@ -651,6 +651,7 @@ func (a *apic) PullTop(forcePull bool) error {
} }
addCounters, deleteCounters := makeAddAndDeleteCounters() addCounters, deleteCounters := makeAddAndDeleteCounters()
// process deleted decisions // process deleted decisions
nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
if err != nil { if err != nil {

View file

@ -38,9 +38,10 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision {
} }
func (c *Controller) GetDecision(gctx *gin.Context) { func (c *Controller) GetDecision(gctx *gin.Context) {
var err error var (
var results []*models.Decision results []*models.Decision
var data []*ent.Decision data []*ent.Decision
)
bouncerInfo, err := getBouncerFromContext(gctx) bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil { if err != nil {
@ -89,6 +90,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
return return
} }
nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID)
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
@ -351,10 +353,13 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if err != nil { if err != nil {
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err return err
} }
ret["deleted"] = FormatDecisions(data) ret["deleted"] = FormatDecisions(data)
gctx.JSON(http.StatusOK, ret) gctx.JSON(http.StatusOK, ret)
return nil return nil
} }
@ -362,9 +367,11 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
var err error var err error
streamStartTime := time.Now().UTC() streamStartTime := time.Now().UTC()
bouncerInfo, err := getBouncerFromContext(gctx) bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil { if err != nil {
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
return return
} }
@ -372,6 +379,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
//For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
//We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
gctx.String(http.StatusOK, "") gctx.String(http.StatusOK, "")
return return
} }

View file

@ -115,6 +115,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc {
func PrometheusMiddleware() gin.HandlerFunc { func PrometheusMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
startTime := time.Now() startTime := time.Now()
LapiRouteHits.With(prometheus.Labels{ LapiRouteHits.With(prometheus.Labels{
"route": c.Request.URL.Path, "route": c.Request.URL.Path,
"method": c.Request.Method}).Inc() "method": c.Request.Method}).Inc()

View file

@ -203,7 +203,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
} }
c.Set(bouncerContextKey, bouncer) c.Set(bouncerContextKey, bouncer)
c.Next() c.Next()
} }
} }

View file

@ -43,6 +43,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
func IdentityHandler(c *gin.Context) interface{} { func IdentityHandler(c *gin.Context) interface{} {
claims := jwt.ExtractClaims(c) claims := jwt.ExtractClaims(c)
machineID := claims[identityKey].(string) machineID := claims[identityKey].(string)
return &models.WatcherAuthRequest{ return &models.WatcherAuthRequest{
MachineID: &machineID, MachineID: &machineID,
} }
@ -93,9 +94,12 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
"ip": c.ClientIP(), "ip": c.ClientIP(),
"cn": extractedCN, "cn": extractedCN,
}).Errorf("error generating password: %s", err) }).Errorf("error generating password: %s", err)
return nil, fmt.Errorf("error generating password") return nil, fmt.Errorf("error generating password")
} }
password := strfmt.Password(pwd) password := strfmt.Password(pwd)
ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
if err != nil { if err != nil {
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
@ -114,27 +118,33 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
}{ }{
Scenarios: []string{}, Scenarios: []string{},
} }
err = c.ShouldBindJSON(&loginInput) err = c.ShouldBindJSON(&loginInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
} }
ret.scenariosInput = loginInput.Scenarios ret.scenariosInput = loginInput.Scenarios
return &ret, nil return &ret, nil
} }
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
var loginInput models.WatcherAuthRequest var (
var err error loginInput models.WatcherAuthRequest
err error
)
ret := authInput{} ret := authInput{}
if err = c.ShouldBindJSON(&loginInput); err != nil { if err = c.ShouldBindJSON(&loginInput); err != nil {
return nil, fmt.Errorf("missing: %w", err) return nil, fmt.Errorf("missing: %w", err)
} }
if err = loginInput.Validate(strfmt.Default); err != nil { if err = loginInput.Validate(strfmt.Default); err != nil {
return nil, err return nil, err
} }
ret.machineID = *loginInput.MachineID ret.machineID = *loginInput.MachineID
password := *loginInput.Password password := *loginInput.Password
ret.scenariosInput = loginInput.Scenarios ret.scenariosInput = loginInput.Scenarios
@ -168,8 +178,10 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
} }
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
var err error var (
var auth *authInput err error
auth *authInput
)
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
auth, err = j.authTLS(c) auth, err = j.authTLS(c)
@ -193,6 +205,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
scenarios += "," + scenario scenarios += "," + scenario
} }
} }
err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
@ -210,6 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress)
err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
@ -228,10 +242,10 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
log.Errorf("bad user agent from : %s", c.ClientIP()) log.Errorf("bad user agent from : %s", c.ClientIP())
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
return &models.WatcherAuthRequest{ return &models.WatcherAuthRequest{
MachineID: &auth.machineID, MachineID: &auth.machineID,
}, nil }, nil
} }
func Authorizator(data interface{}, c *gin.Context) bool { func Authorizator(data interface{}, c *gin.Context) bool {

View file

@ -18,5 +18,6 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) {
} }
ret.APIKey = NewAPIKey(dbClient) ret.APIKey = NewAPIKey(dbClient)
return ret, nil return ret, nil
} }

View file

@ -36,32 +36,40 @@ func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509
ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
return nil, err return nil, err
} }
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
if err != nil { if err != nil {
ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
return nil, err return nil, err
} }
ocspURL, err := url.Parse(server) ocspURL, err := url.Parse(server)
if err != nil { if err != nil {
ta.logger.Error("TLSAuth: cannot parse OCSP URL") ta.logger.Error("TLSAuth: cannot parse OCSP URL")
return nil, err return nil, err
} }
httpRequest.Header.Add("Content-Type", "application/ocsp-request") httpRequest.Header.Add("Content-Type", "application/ocsp-request")
httpRequest.Header.Add("Accept", "application/ocsp-response") httpRequest.Header.Add("Accept", "application/ocsp-response")
httpRequest.Header.Add("host", ocspURL.Host) httpRequest.Header.Add("host", ocspURL.Host)
httpClient := &http.Client{} httpClient := &http.Client{}
httpResponse, err := httpClient.Do(httpRequest) httpResponse, err := httpClient.Do(httpRequest)
if err != nil { if err != nil {
ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
return nil, err return nil, err
} }
defer httpResponse.Body.Close() defer httpResponse.Body.Close()
output, err := io.ReadAll(httpResponse.Body) output, err := io.ReadAll(httpResponse.Body)
if err != nil { if err != nil {
ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP") ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
return nil, err return nil, err
} }
ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
return ocspResponse, err return ocspResponse, err
} }
@ -72,10 +80,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC()) ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
return true return true
} }
if cert.NotBefore.UTC().After(now) { if cert.NotBefore.UTC().After(now) {
ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC()) ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
return true return true
} }
return false return false
} }
@ -84,12 +94,14 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
return false, nil return false, nil
} }
for _, server := range cert.OCSPServer { for _, server := range cert.OCSPServer {
ocspResponse, err := ta.ocspQuery(server, cert, issuer) ocspResponse, err := ta.ocspQuery(server, cert, issuer)
if err != nil { if err != nil {
ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
continue continue
} }
switch ocspResponse.Status { switch ocspResponse.Status {
case ocsp.Good: case ocsp.Good:
return false, nil return false, nil
@ -100,7 +112,9 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
continue continue
} }
} }
log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
return true, nil return true, nil
} }
@ -109,24 +123,29 @@ func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
ta.logger.Warn("no crl_path, skipping CRL check") ta.logger.Warn("no crl_path, skipping CRL check")
return false, nil return false, nil
} }
crlContent, err := os.ReadFile(ta.CrlPath) crlContent, err := os.ReadFile(ta.CrlPath)
if err != nil { if err != nil {
ta.logger.Warnf("could not read CRL file, skipping check: %s", err) ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
return false, nil return false, nil
} }
crl, err := x509.ParseCRL(crlContent) crl, err := x509.ParseCRL(crlContent)
if err != nil { if err != nil {
ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
return false, nil return false, nil
} }
if crl.HasExpired(time.Now().UTC()) { if crl.HasExpired(time.Now().UTC()) {
ta.logger.Warn("CRL has expired, will still validate the cert against it.") ta.logger.Warn("CRL has expired, will still validate the cert against it.")
} }
for _, revoked := range crl.TBSCertList.RevokedCertificates { for _, revoked := range crl.TBSCertList.RevokedCertificates {
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
return true, fmt.Errorf("client certificate is revoked by CRL") return true, fmt.Errorf("client certificate is revoked by CRL")
} }
} }
return false, nil return false, nil
} }
@ -143,6 +162,7 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
} else { } else {
ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
} }
revoked, err := ta.isOCSPRevoked(cert, issuer) revoked, err := ta.isOCSPRevoked(cert, issuer)
if err != nil { if err != nil {
ta.revokationCache[sn] = cacheEntry{ ta.revokationCache[sn] = cacheEntry{
@ -150,22 +170,27 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
err: err, err: err,
timestamp: time.Now().UTC(), timestamp: time.Now().UTC(),
} }
return true, err return true, err
} }
if revoked { if revoked {
ta.revokationCache[sn] = cacheEntry{ ta.revokationCache[sn] = cacheEntry{
revoked: revoked, revoked: revoked,
err: err, err: err,
timestamp: time.Now().UTC(), timestamp: time.Now().UTC(),
} }
return true, nil return true, nil
} }
revoked, err = ta.isCRLRevoked(cert) revoked, err = ta.isCRLRevoked(cert)
ta.revokationCache[sn] = cacheEntry{ ta.revokationCache[sn] = cacheEntry{
revoked: revoked, revoked: revoked,
err: err, err: err,
timestamp: time.Now().UTC(), timestamp: time.Now().UTC(),
} }
return revoked, err return revoked, err
} }
@ -173,6 +198,7 @@ func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (
if ta.isExpired(cert) { if ta.isExpired(cert) {
return true, nil return true, nil
} }
revoked, err := ta.isRevoked(cert, issuer) revoked, err := ta.isRevoked(cert, issuer)
if err != nil { if err != nil {
//Fail securely, if we can't check the revocation status, let's consider the cert invalid //Fail securely, if we can't check the revocation status, let's consider the cert invalid
@ -189,24 +215,30 @@ func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
if ou == "" { if ou == "" {
return fmt.Errorf("empty ou isn't allowed") return fmt.Errorf("empty ou isn't allowed")
} }
//drop & warn on duplicate ou //drop & warn on duplicate ou
ok := true ok := true
for _, validOu := range ta.AllowedOUs { for _, validOu := range ta.AllowedOUs {
if validOu == ou { if validOu == ou {
ta.logger.Warningf("dropping duplicate ou %s", ou) ta.logger.Warningf("dropping duplicate ou %s", ou)
ok = false ok = false
} }
} }
if ok { if ok {
ta.AllowedOUs = append(ta.AllowedOUs, ou) ta.AllowedOUs = append(ta.AllowedOUs, ou)
} }
} }
return nil return nil
} }
func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
//Checks cert validity, Returns true + CN if client cert matches requested OU //Checks cert validity, Returns true + CN if client cert matches requested OU
var clientCert *x509.Certificate var clientCert *x509.Certificate
if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
//do not error if it's not TLS or there are no peer certs //do not error if it's not TLS or there are no peer certs
return false, "", nil return false, "", nil
@ -215,6 +247,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
if len(c.Request.TLS.VerifiedChains) > 0 { if len(c.Request.TLS.VerifiedChains) > 0 {
validOU := false validOU := false
clientCert = c.Request.TLS.VerifiedChains[0][0] clientCert = c.Request.TLS.VerifiedChains[0][0]
for _, ou := range clientCert.Subject.OrganizationalUnit { for _, ou := range clientCert.Subject.OrganizationalUnit {
for _, allowedOu := range ta.AllowedOUs { for _, allowedOu := range ta.AllowedOUs {
if allowedOu == ou { if allowedOu == ou {
@ -223,21 +256,27 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
} }
} }
} }
if !validOU { if !validOU {
return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
} }
revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
if err != nil { if err != nil {
ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err) return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err)
} }
if revoked { if revoked {
return false, "", fmt.Errorf("client certificate is revoked") return false, "", fmt.Errorf("client certificate is revoked")
} }
ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
return true, clientCert.Subject.CommonName, nil return true, clientCert.Subject.CommonName, nil
} }
return false, "", fmt.Errorf("no verified cert in request") return false, "", fmt.Errorf("no verified cert in request")
} }
@ -248,9 +287,11 @@ func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Durati
CrlPath: crlPath, CrlPath: crlPath,
logger: logger, logger: logger,
} }
err := ta.SetAllowedOu(allowedOus) err := ta.SetAllowedOu(allowedOus)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ta, nil return ta, nil
} }

View file

@ -205,12 +205,15 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order
eventsCount := len(events) eventsCount := len(events)
p.Logger.Infof("received %d events", eventsCount) p.Logger.Infof("received %d events", eventsCount)
for i, event := range reversedEvents { for i, event := range reversedEvents {
if err := p.handleEvent(event, sync); err != nil { if err := p.handleEvent(event, sync); err != nil {
p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err) p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err)
} }
p.Logger.Debugf("handled event %d/%d", i, eventsCount) p.Logger.Debugf("handled event %d/%d", i, eventsCount)
} }
p.Logger.Debugf("finished handling events") p.Logger.Debugf("finished handling events")
//Don't update the timestamp in DB, as a "real" LAPI might be running //Don't update the timestamp in DB, as a "real" LAPI might be running
//Worst case, crowdsec will receive a few duplicated events and will discard them //Worst case, crowdsec will receive a few duplicated events and will discard them
@ -223,16 +226,19 @@ func (p *Papi) Pull() error {
p.Logger.Infof("Starting Polling API Pull") p.Logger.Infof("Starting Polling API Pull")
lastTimestamp := time.Time{} lastTimestamp := time.Time{}
lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
if err != nil { if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
} }
//value doesn't exist, it's first time we're pulling //value doesn't exist, it's first time we're pulling
if lastTimestampStr == nil { if lastTimestampStr == nil {
binTime, err := lastTimestamp.MarshalText() binTime, err := lastTimestamp.MarshalText()
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal last timestamp: %w", err) return fmt.Errorf("failed to marshal last timestamp: %w", err)
} }
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
p.Logger.Errorf("error setting papi pull last key: %s", err) p.Logger.Errorf("error setting papi pull last key: %s", err)
} else { } else {
@ -245,10 +251,12 @@ func (p *Papi) Pull() error {
} }
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
for event := range p.Client.Start(lastTimestamp) { for event := range p.Client.Start(lastTimestamp) {
logger := p.Logger.WithField("request-id", event.RequestId) logger := p.Logger.WithField("request-id", event.RequestId)
//update last timestamp in database //update last timestamp in database
newTime := time.Now().UTC() newTime := time.Now().UTC()
binTime, err := newTime.MarshalText() binTime, err := newTime.MarshalText()
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal last timestamp: %w", err) return fmt.Errorf("failed to marshal last timestamp: %w", err)
@ -262,11 +270,11 @@ func (p *Papi) Pull() error {
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
return fmt.Errorf("failed to update last timestamp: %w", err) return fmt.Errorf("failed to update last timestamp: %w", err)
} else { }
logger.Debugf("set last timestamp to %s", newTime) logger.Debugf("set last timestamp to %s", newTime)
} }
}
return nil return nil
} }
@ -274,6 +282,7 @@ func (p *Papi) SyncDecisions() error {
defer trace.CatchPanic("lapi/syncDecisionsToCAPI") defer trace.CatchPanic("lapi/syncDecisionsToCAPI")
var cache models.DecisionsDeleteRequest var cache models.DecisionsDeleteRequest
ticker := time.NewTicker(p.SyncInterval) ticker := time.NewTicker(p.SyncInterval)
p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval)
@ -281,10 +290,13 @@ func (p *Papi) SyncDecisions() error {
select { select {
case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others?
p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache))
if len(cache) == 0 { if len(cache) == 0 {
return nil return nil
} }
go p.SendDeletedDecisions(&cache) go p.SendDeletedDecisions(&cache)
return nil return nil
case <-ticker.C: case <-ticker.C:
if len(cache) > 0 { if len(cache) > 0 {
@ -293,15 +305,19 @@ func (p *Papi) SyncDecisions() error {
cache = make([]models.DecisionsDeleteRequestItem, 0) cache = make([]models.DecisionsDeleteRequestItem, 0)
p.mu.Unlock() p.mu.Unlock()
p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy)) p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy))
go p.SendDeletedDecisions(&cacheCopy) go p.SendDeletedDecisions(&cacheCopy)
} }
case deletedDecisions := <-p.Channels.DeleteDecisionChannel: case deletedDecisions := <-p.Channels.DeleteDecisionChannel:
if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) { if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) {
var tmpDecisions []models.DecisionsDeleteRequestItem var tmpDecisions []models.DecisionsDeleteRequestItem
p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions)) p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions))
for _, decision := range deletedDecisions { for _, decision := range deletedDecisions {
tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID)) tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID))
} }
p.mu.Lock() p.mu.Lock()
cache = append(cache, tmpDecisions...) cache = append(cache, tmpDecisions...)
p.mu.Unlock() p.mu.Unlock()
@ -311,33 +327,42 @@ func (p *Papi) SyncDecisions() error {
} }
func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) { func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) {
var (
var cache []models.DecisionsDeleteRequestItem = *cacheOrig cache []models.DecisionsDeleteRequestItem = *cacheOrig
var send models.DecisionsDeleteRequest send models.DecisionsDeleteRequest
)
bulkSize := 50 bulkSize := 50
pageStart := 0 pageStart := 0
pageEnd := bulkSize pageEnd := bulkSize
for { for {
if pageEnd >= len(cache) { if pageEnd >= len(cache) {
send = cache[pageStart:] send = cache[pageStart:]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
if err != nil { if err != nil {
p.Logger.Errorf("sending deleted decisions to central API: %s", err) p.Logger.Errorf("sending deleted decisions to central API: %s", err)
return return
} }
break break
} }
send = cache[pageStart:pageEnd] send = cache[pageStart:pageEnd]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send) _, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
if err != nil { if err != nil {
//we log it here as well, because the return value of func might be discarded //we log it here as well, because the return value of func might be discarded
p.Logger.Errorf("sending deleted decisions to central API: %s", err) p.Logger.Errorf("sending deleted decisions to central API: %s", err)
} }
pageStart += bulkSize pageStart += bulkSize
pageEnd += bulkSize pageEnd += bulkSize
} }

View file

@ -40,17 +40,18 @@ type forcePull struct {
func DecisionCmd(message *Message, p *Papi, sync bool) error { func DecisionCmd(message *Message, p *Papi, sync bool) error {
switch message.Header.OperationCmd { switch message.Header.OperationCmd {
case "delete": case "delete":
data, err := json.Marshal(message.Data) data, err := json.Marshal(message.Data)
if err != nil { if err != nil {
return err return err
} }
UUIDs := make([]string, 0) UUIDs := make([]string, 0)
deleteDecisionMsg := deleteDecisions{ deleteDecisionMsg := deleteDecisions{
Decisions: make([]string, 0), Decisions: make([]string, 0),
} }
if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil { if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil {
return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
} }
UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...) UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...)
@ -59,10 +60,13 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
filter := make(map[string][]string) filter := make(map[string][]string)
filter["uuid"] = UUIDs filter["uuid"] = UUIDs
_, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter) _, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err) return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err)
} }
decisions := make([]*models.Decision, 0) decisions := make([]*models.Decision, 0)
for _, deletedDecision := range deletedDecisions { for _, deletedDecision := range deletedDecisions {
log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type) log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type)
dec := &models.Decision{ dec := &models.Decision{
@ -92,6 +96,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
if err != nil { if err != nil {
return err return err
} }
alert := &models.Alert{} alert := &models.Alert{}
if err := json.Unmarshal(data, alert); err != nil { if err := json.Unmarshal(data, alert); err != nil {
@ -105,10 +110,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID) log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID)
alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
} }
if alert.StopAt == nil || *alert.StopAt == "" { if alert.StopAt == nil || *alert.StopAt == "" {
log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID) log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID)
alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
} }
alert.EventsCount = ptr.Of(int32(0)) alert.EventsCount = ptr.Of(int32(0))
alert.Capacity = ptr.Of(int32(0)) alert.Capacity = ptr.Of(int32(0))
alert.Leakspeed = ptr.Of("") alert.Leakspeed = ptr.Of("")
@ -128,12 +135,14 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
alert.Source.Scope = ptr.Of(types.ConsoleOrigin) alert.Source.Scope = ptr.Of(types.ConsoleOrigin)
alert.Source.Value = &message.Header.Source.User alert.Source.Value = &message.Header.Source.User
} }
alert.Scenario = &message.Header.Message alert.Scenario = &message.Header.Message
for _, decision := range alert.Decisions { for _, decision := range alert.Decisions {
if *decision.Scenario == "" { if *decision.Scenario == "" {
decision.Scenario = &message.Header.Message decision.Scenario = &message.Header.Message
} }
log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID) log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID)
} }
@ -157,6 +166,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
log.Infof("Ignoring management command from PAPI in sync mode") log.Infof("Ignoring management command from PAPI in sync mode")
return nil return nil
} }
switch message.Header.OperationCmd { switch message.Header.OperationCmd {
case "reauth": case "reauth":
log.Infof("Received reauth command from PAPI, resetting token") log.Infof("Received reauth command from PAPI, resetting token")
@ -187,12 +197,12 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
Duration: &forcePullMsg.Blocklist.Duration, Duration: &forcePullMsg.Blocklist.Duration,
}, true) }, true)
if err != nil { if err != nil {
return fmt.Errorf("failed to force pull operation: %s", err) return fmt.Errorf("failed to force pull operation: %w", err)
} }
} }
default: default:
return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
} }
return nil return nil
} }