Compare commits

...

11 commits

Author SHA1 Message Date
sabban
9d678a8070 Controller is now exported 2022-07-05 17:02:35 +02:00
sabban
ffe88905aa Merge branch 'improve_notifications' of github.com:crowdsecurity/crowdsec into improve_notifications 2022-06-16 17:26:17 +02:00
sabban
7976558256 change error management 2022-06-16 17:23:44 +02:00
sabban
5fb61756df non-blocking send 2022-06-16 17:23:44 +02:00
sabban
c994e50844 debug leftoover 2022-06-16 17:23:44 +02:00
sabban
54ebaeb412 fix 2022-06-16 17:23:44 +02:00
sabban
ae2767c8a2 * add a way to test notifications
* fix an antipattern in broker system with tomb
* kill an unneeded goroutine
2022-06-16 17:23:41 +02:00
sabban
644d0fe955 non-blocking send 2022-05-20 14:40:10 +02:00
sabban
b168820dbe debug leftoover 2022-05-20 13:34:58 +02:00
sabban
e9a661c87f fix 2022-05-20 11:55:50 +02:00
sabban
4efedbed34 * add a way to test notifications
* fix an antipattern in broker system with tomb
* kill an unneeded goroutine
2022-05-20 09:50:51 +02:00
5 changed files with 119 additions and 33 deletions

View file

@ -1,20 +1,28 @@
package main package main
import ( import (
"context"
"encoding/csv" "encoding/csv"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/fs" "io/fs"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"time"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/go-openapi/strfmt"
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"gopkg.in/tomb.v2"
) )
type NotificationsCfg struct { type NotificationsCfg struct {
@ -135,6 +143,80 @@ func NewNotificationsCmd() *cobra.Command {
}, },
} }
cmdNotifications.AddCommand(cmdNotificationsInspect) cmdNotifications.AddCommand(cmdNotificationsInspect)
var cmdNotificationsReinject = &cobra.Command{
Use: "reinject",
Short: "reinject alerts into notifications system",
Long: `Reinject alerts into notifications system`,
Example: `cscli notifications reinject <alert_id> <plugin_name>`,
Args: cobra.ExactArgs(2),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
)
if len(args) != 2 {
printHelp(cmd)
return errors.New("Wrong number of argument")
}
id, err := strconv.Atoi(args[0])
if err != nil {
return errors.New(fmt.Sprintf("bad alert id %s", args[0]))
}
if err := csConfig.LoadAPIClient(); err != nil {
return errors.Wrapf(err, "loading api client")
}
if csConfig.API.Client == nil {
return errors.New("There is no configuration on 'api_client:'")
}
if csConfig.API.Client.Credentials == nil {
return errors.New(fmt.Sprintf("Please provide credentials for the API in '%s'", csConfig.API.Client.CredentialsFilePath))
}
apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL)
Client, err = apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Client.Credentials.Login,
Password: strfmt.Password(csConfig.API.Client.Credentials.Password),
UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
URL: apiURL,
VersionPrefix: "v1",
})
alert, _, err := Client.Alerts.GetByID(context.Background(), id)
if err != nil {
errors.Wrapf(err, "can't find alert with id %s: %s", args[0])
}
err = pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths)
if err != nil {
errors.Wrapf(err, "Can't initialize plugins")
}
pluginTomb.Go(func() error {
pluginBroker.Run(&pluginTomb)
return nil
})
loop:
for {
select {
case pluginBroker.PluginChannel <- csplugin.ProfileAlert{
ProfileID: 1,
Alert: alert,
}:
break loop
default:
time.Sleep(50 * time.Millisecond)
log.Info("sleeping\n")
}
}
pluginTomb.Kill(errors.New("terminating"))
pluginTomb.Wait()
return nil
},
}
cmdNotifications.AddCommand(cmdNotificationsReinject)
return cmdNotifications return cmdNotifications
} }

View file

@ -152,7 +152,7 @@ func TestCreateAlertChannels(t *testing.T) {
if err != nil { if err != nil {
log.Fatalln(err.Error()) log.Fatalln(err.Error())
} }
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.Controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController() apiServer.InitController()
loginResp, err := LoginToTestAPI(apiServer.router, config) loginResp, err := LoginToTestAPI(apiServer.router, config)
@ -166,7 +166,7 @@ func TestCreateAlertChannels(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
pd = <-apiServer.controller.PluginChannel pd = <-apiServer.Controller.PluginChannel
wg.Done() wg.Done()
}() }()

View file

@ -34,9 +34,9 @@ var (
type APIServer struct { type APIServer struct {
URL string URL string
TLS *csconfig.TLSCfg TLS *csconfig.TLSCfg
Controller *controllers.Controller
dbClient *database.Client dbClient *database.Client
logFile string logFile string
controller *controllers.Controller
flushScheduler *gocron.Scheduler flushScheduler *gocron.Scheduler
router *gin.Engine router *gin.Engine
httpServer *http.Server httpServer *http.Server
@ -230,7 +230,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
TLS: config.TLS, TLS: config.TLS,
logFile: logFile, logFile: logFile,
dbClient: dbClient, dbClient: dbClient,
controller: controller, Controller: controller,
flushScheduler: flushScheduler, flushScheduler: flushScheduler,
router: router, router: router,
apic: apiClient, apic: apiClient,
@ -370,12 +370,11 @@ func (s *APIServer) Shutdown() error {
} }
func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) { func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
s.controller.PluginChannel = broker.PluginChannel s.Controller.PluginChannel = broker.PluginChannel
} }
func (s *APIServer) InitController() error { func (s *APIServer) InitController() error {
err := s.Controller.Init()
err := s.controller.Init()
if err != nil { if err != nil {
return errors.Wrap(err, "controller init") return errors.Wrap(err, "controller init")
} }
@ -386,7 +385,7 @@ func (s *APIServer) InitController() error {
} else { } else {
cacheExpiration = time.Hour cacheExpiration = time.Hour
} }
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, s.Controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
cacheExpiration, cacheExpiration,
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"component": "tls-auth", "component": "tls-auth",
@ -395,7 +394,7 @@ func (s *APIServer) InitController() error {
if err != nil { if err != nil {
return errors.Wrap(err, "while creating TLS auth for agents") return errors.Wrap(err, "while creating TLS auth for agents")
} }
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, s.Controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
cacheExpiration, cacheExpiration,
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"component": "tls-auth", "component": "tls-auth",

View file

@ -96,9 +96,9 @@ func (pb *PluginBroker) Kill() {
} }
} }
func (pb *PluginBroker) Run(tomb *tomb.Tomb) { func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) {
//we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) //we get signaled via the channel when notifications need to be delivered to plugin (via the watcher)
pb.watcher.Start(tomb) pb.watcher.Start(&tomb.Tomb{})
for { for {
select { select {
case profileAlert := <-pb.PluginChannel: case profileAlert := <-pb.PluginChannel:
@ -116,8 +116,25 @@ func (pb *PluginBroker) Run(tomb *tomb.Tomb) {
log.WithField("plugin:", pluginName).Error(err) log.WithField("plugin:", pluginName).Error(err)
} }
}() }()
case <-pluginTomb.Dying():
case <-tomb.Dying(): pb.watcher.tomb.Kill(errors.New("Terminating"))
loop:
for {
select {
case pluginName := <-pb.watcher.PluginEvents:
// this can be ran in goroutine, but then locks will be needed
pluginMutex.Lock()
log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName)
tmpAlerts := pb.alertsByPluginName[pluginName]
pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0)
pluginMutex.Unlock()
if err := pb.pushNotificationsToPlugin(pluginName, tmpAlerts); err != nil {
log.WithField("plugin:", pluginName).Error(err)
}
case <-pb.watcher.tomb.Dead():
break loop
}
}
log.Info("killing all plugins") log.Info("killing all plugins")
pb.Kill() pb.Kill()
return return
@ -133,7 +150,10 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) {
pluginMutex.Lock() pluginMutex.Lock()
pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert) pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert)
pluginMutex.Unlock() pluginMutex.Unlock()
pb.watcher.Inserts <- pluginName if _, ok := pb.watcher.PluginConfigByName[pluginName]; ok {
curr, _ := pb.watcher.AlertCountByPluginName.Get(pluginName)
pb.watcher.AlertCountByPluginName.Set(pluginName, curr+1)
}
} }
} }
func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool { func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {

View file

@ -73,11 +73,6 @@ func (pw *PluginWatcher) Start(tomb *tomb.Tomb) {
return nil return nil
}) })
} }
pw.tomb.Go(func() error {
pw.watchPluginAlertCounts()
return nil
})
} }
func (pw *PluginWatcher) watchPluginTicker(pluginName string) { func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
@ -139,21 +134,11 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
} }
case <-pw.tomb.Dying(): case <-pw.tomb.Dying():
ticker.Stop() ticker.Stop()
return
}
}
}
func (pw *PluginWatcher) watchPluginAlertCounts() {
for {
select { select {
case pluginName := <-pw.Inserts: case pw.PluginEvents <- pluginName:
//we only "count" pending alerts, and watchPluginTicker is actually going to send it default:
if _, ok := pw.PluginConfigByName[pluginName]; ok {
curr, _ := pw.AlertCountByPluginName.Get(pluginName)
pw.AlertCountByPluginName.Set(pluginName, curr+1)
} }
case <-pw.tomb.Dying(): log.Tracef("sending alerts to %s", pluginName)
return return
} }
} }