From f3ea88f64ce7a594830558c84bc6f196ddddc323 Mon Sep 17 00:00:00 2001 From: Laurence Jones Date: Wed, 21 Feb 2024 13:40:38 +0000 Subject: [PATCH 1/6] Appsec unix socket (#2737) * Appsec socket * Patch detection of nil listenaddr * Allow TLS unix socket * Merge diff issue --- pkg/acquisition/modules/appsec/appsec.go | 53 ++++++++++++++++++------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 4e2ff0bd2..a3c8c7dd8 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" + "os" "sync" "time" @@ -34,6 +36,7 @@ var ( // configuration structure of the acquis for the application security engine type AppsecSourceConfig struct { ListenAddr string `yaml:"listen_addr"` + ListenSocket string `yaml:"listen_socket"` CertFilePath string `yaml:"cert_file"` KeyFilePath string `yaml:"key_file"` Path string `yaml:"path"` @@ -97,7 +100,7 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { return errors.Wrap(err, "Cannot parse appsec configuration") } - if w.config.ListenAddr == "" { + if w.config.ListenAddr == "" && w.config.ListenSocket == "" { w.config.ListenAddr = "127.0.0.1:7422" } @@ -123,7 +126,12 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { } if w.config.Name == "" { - w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + if w.config.ListenSocket != "" && w.config.ListenAddr == "" { + w.config.Name = w.config.ListenSocket + } + if w.config.ListenSocket == "" { + w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) + } } csConfig := csconfig.GetConfig() @@ -251,23 +259,44 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) return runner.Run(t) }) } - - w.logger.Infof("Starting Appsec server on %s%s", w.config.ListenAddr, w.config.Path) + t.Go(func() error { + if w.config.ListenSocket != "" { + w.logger.Infof("creating unix socket %s", w.config.ListenSocket) + _ = os.RemoveAll(w.config.ListenSocket) + listener, err := net.Listen("unix", w.config.ListenSocket) + if err != nil { + return errors.Wrap(err, "Appsec server failed") + } + defer listener.Close() + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ServeTLS(listener, w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + return errors.Wrap(err, "Appsec server failed") + } + } + return nil + }) t.Go(func() error { var err error - if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { - err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) - } else { - err = w.server.ListenAndServe() - } + if w.config.ListenAddr != "" { + w.logger.Infof("creating TCP server on %s", w.config.ListenAddr) + if w.config.CertFilePath != "" && w.config.KeyFilePath != "" { + err = w.server.ListenAndServeTLS(w.config.CertFilePath, w.config.KeyFilePath) + } else { + err = w.server.ListenAndServe() + } - if err != nil && err != http.ErrServerClosed { - return errors.Wrap(err, "Appsec server failed") + if err != nil && err != http.ErrServerClosed { + return errors.Wrap(err, "Appsec server failed") + } } return nil }) <-t.Dying() - w.logger.Infof("Stopping Appsec server on %s%s", w.config.ListenAddr, w.config.Path) + w.logger.Info("Shutting down Appsec server") //xx let's clean up the appsec runners :) appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) w.server.Shutdown(context.TODO()) From 3e3df5e4c6e6deb1ef36bb406e86a7ebc8c30f06 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:04:36 +0100 Subject: [PATCH 2/6] refact "cscli config", remove flag "cscli restore --old-backup" (#2832) * refact "cscli config show" * refact "cscli config backup" * refact "cscli confgi show-yaml" * refact "cscli config restore" * refact "cscli config feature-flags" * cscli restore: remove 'old-backup' option * lint (whitespace, wrapped errors) --- cmd/crowdsec-cli/config.go | 26 +++- cmd/crowdsec-cli/config_backup.go | 99 ++++++------- cmd/crowdsec-cli/config_feature_flags.go | 25 ++-- cmd/crowdsec-cli/config_restore.go | 177 ++++++++--------------- cmd/crowdsec-cli/config_show.go | 37 +++-- cmd/crowdsec-cli/config_showyaml.go | 12 +- cmd/crowdsec-cli/main.go | 2 +- 7 files changed, 168 insertions(+), 210 deletions(-) diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index e60246db7..e88845798 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -4,19 +4,29 @@ import ( "github.com/spf13/cobra" ) -func NewConfigCmd() *cobra.Command { - cmdConfig := &cobra.Command{ +type cliConfig struct { + cfg configGetter +} + +func NewCLIConfig(cfg configGetter) *cliConfig { + return &cliConfig{ + cfg: cfg, + } +} + +func (cli *cliConfig) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "config [command]", Short: "Allows to view current config", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, } - cmdConfig.AddCommand(NewConfigShowCmd()) - cmdConfig.AddCommand(NewConfigShowYAMLCmd()) - cmdConfig.AddCommand(NewConfigBackupCmd()) - cmdConfig.AddCommand(NewConfigRestoreCmd()) - cmdConfig.AddCommand(NewConfigFeatureFlagsCmd()) + cmd.AddCommand(cli.newShowCmd()) + cmd.AddCommand(cli.newShowYAMLCmd()) + cmd.AddCommand(cli.newBackupCmd()) + cmd.AddCommand(cli.newRestoreCmd()) + cmd.AddCommand(cli.newFeatureFlagsCmd()) - return cmdConfig + return cmd } diff --git a/cmd/crowdsec-cli/config_backup.go b/cmd/crowdsec-cli/config_backup.go index 9414fa510..d1e4a3935 100644 --- a/cmd/crowdsec-cli/config_backup.go +++ b/cmd/crowdsec-cli/config_backup.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -13,8 +14,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func backupHub(dirPath string) error { - hub, err := require.Hub(csConfig, nil, nil) +func (cli *cliConfig) backupHub(dirPath string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) if err != nil { return err } @@ -32,7 +33,7 @@ func backupHub(dirPath string) error { itemDirectory := fmt.Sprintf("%s/%s/", dirPath, itemType) if err = os.MkdirAll(itemDirectory, os.ModePerm); err != nil { - return fmt.Errorf("error while creating %s : %s", itemDirectory, err) + return fmt.Errorf("error while creating %s: %w", itemDirectory, err) } upstreamParsers := []string{} @@ -41,18 +42,18 @@ func backupHub(dirPath string) error { clog = clog.WithFields(log.Fields{ "file": v.Name, }) - if !v.State.Installed { //only backup installed ones - clog.Debugf("[%s] : not installed", k) + if !v.State.Installed { // only backup installed ones + clog.Debugf("[%s]: not installed", k) continue } - //for the local/tainted ones, we back up the full file + // for the local/tainted ones, we back up the full file if v.State.Tainted || v.State.IsLocal() || !v.State.UpToDate { - //we need to backup stages for parsers + // we need to backup stages for parsers if itemType == cwhub.PARSERS || itemType == cwhub.POSTOVERFLOWS { fstagedir := fmt.Sprintf("%s%s", itemDirectory, v.Stage) if err = os.MkdirAll(fstagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage dir %s : %s", fstagedir, err) + return fmt.Errorf("error while creating stage dir %s: %w", fstagedir, err) } } @@ -60,7 +61,7 @@ func backupHub(dirPath string) error { tfile := fmt.Sprintf("%s%s/%s", itemDirectory, v.Stage, v.FileName) if err = CopyFile(v.State.LocalPath, tfile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itemType, v.State.LocalPath, tfile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itemType, v.State.LocalPath, tfile, err) } clog.Infof("local/tainted saved %s to %s", v.State.LocalPath, tfile) @@ -68,21 +69,21 @@ func backupHub(dirPath string) error { continue } - clog.Debugf("[%s] : from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) + clog.Debugf("[%s]: from hub, just backup name (up-to-date:%t)", k, v.State.UpToDate) clog.Infof("saving, version:%s, up-to-date:%t", v.Version, v.State.UpToDate) upstreamParsers = append(upstreamParsers, v.Name) } - //write the upstream items + // write the upstream items upstreamParsersFname := fmt.Sprintf("%s/upstream-%s.json", itemDirectory, itemType) upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") if err != nil { - return fmt.Errorf("failed marshaling upstream parsers : %s", err) + return fmt.Errorf("failed marshaling upstream parsers: %w", err) } err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0o644) if err != nil { - return fmt.Errorf("unable to write to %s %s : %s", itemType, upstreamParsersFname, err) + return fmt.Errorf("unable to write to %s %s: %w", itemType, upstreamParsersFname, err) } clog.Infof("Wrote %d entries for %s to %s", len(upstreamParsers), itemType, upstreamParsersFname) @@ -102,11 +103,13 @@ func backupHub(dirPath string) error { - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func backupConfigToDirectory(dirPath string) error { +func (cli *cliConfig) backup(dirPath string) error { var err error + cfg := cli.cfg() + if dirPath == "" { - return fmt.Errorf("directory path can't be empty") + return errors.New("directory path can't be empty") } log.Infof("Starting configuration backup") @@ -121,10 +124,10 @@ func backupConfigToDirectory(dirPath string) error { return fmt.Errorf("while creating %s: %w", dirPath, err) } - if csConfig.ConfigPaths.SimulationFilePath != "" { + if cfg.ConfigPaths.SimulationFilePath != "" { backupSimulation := filepath.Join(dirPath, "simulation.yaml") - if err = CopyFile(csConfig.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", csConfig.ConfigPaths.SimulationFilePath, backupSimulation, err) + if err = CopyFile(cfg.ConfigPaths.SimulationFilePath, backupSimulation); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.ConfigPaths.SimulationFilePath, backupSimulation, err) } log.Infof("Saved simulation to %s", backupSimulation) @@ -134,22 +137,22 @@ func backupConfigToDirectory(dirPath string) error { - backup AcquisitionFilePath - backup the other files of acquisition directory */ - if csConfig.Crowdsec != nil && csConfig.Crowdsec.AcquisitionFilePath != "" { + if cfg.Crowdsec != nil && cfg.Crowdsec.AcquisitionFilePath != "" { backupAcquisition := filepath.Join(dirPath, "acquis.yaml") - if err = CopyFile(csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.Crowdsec.AcquisitionFilePath, backupAcquisition, err) + if err = CopyFile(cfg.Crowdsec.AcquisitionFilePath, backupAcquisition); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.Crowdsec.AcquisitionFilePath, backupAcquisition, err) } } acquisBackupDir := filepath.Join(dirPath, "acquis") if err = os.Mkdir(acquisBackupDir, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %s", acquisBackupDir, err) + return fmt.Errorf("error while creating %s: %w", acquisBackupDir, err) } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { /*if it was the default one, it was already backup'ed*/ - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { continue } @@ -169,56 +172,48 @@ func backupConfigToDirectory(dirPath string) error { if ConfigFilePath != "" { backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if err = CopyFile(ConfigFilePath, backupMain); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", ConfigFilePath, backupMain, err) + return fmt.Errorf("failed copy %s to %s: %w", ConfigFilePath, backupMain, err) } log.Infof("Saved default yaml to %s", backupMain) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.OnlineClient != nil && cfg.API.Server.OnlineClient.CredentialsFilePath != "" { backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) + if err = CopyFile(cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.OnlineClient.CredentialsFilePath, backupCAPICreds, err) } log.Infof("Saved online API credentials to %s", backupCAPICreds) } - if csConfig.API != nil && csConfig.API.Client != nil && csConfig.API.Client.CredentialsFilePath != "" { + if cfg.API != nil && cfg.API.Client != nil && cfg.API.Client.CredentialsFilePath != "" { backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if err = CopyFile(csConfig.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Client.CredentialsFilePath, backupLAPICreds, err) + if err = CopyFile(cfg.API.Client.CredentialsFilePath, backupLAPICreds); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Client.CredentialsFilePath, backupLAPICreds, err) } log.Infof("Saved local API credentials to %s", backupLAPICreds) } - if csConfig.API != nil && csConfig.API.Server != nil && csConfig.API.Server.ProfilesPath != "" { + if cfg.API != nil && cfg.API.Server != nil && cfg.API.Server.ProfilesPath != "" { backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if err = CopyFile(csConfig.API.Server.ProfilesPath, backupProfiles); err != nil { - return fmt.Errorf("failed copy %s to %s: %s", csConfig.API.Server.ProfilesPath, backupProfiles, err) + if err = CopyFile(cfg.API.Server.ProfilesPath, backupProfiles); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", cfg.API.Server.ProfilesPath, backupProfiles, err) } log.Infof("Saved profiles to %s", backupProfiles) } - if err = backupHub(dirPath); err != nil { - return fmt.Errorf("failed to backup hub config: %s", err) + if err = cli.backupHub(dirPath); err != nil { + return fmt.Errorf("failed to backup hub config: %w", err) } return nil } -func runConfigBackup(cmd *cobra.Command, args []string) error { - if err := backupConfigToDirectory(args[0]); err != nil { - return fmt.Errorf("failed to backup config: %w", err) - } - - return nil -} - -func NewConfigBackupCmd() *cobra.Command { - cmdConfigBackup := &cobra.Command{ +func (cli *cliConfig) newBackupCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `backup "directory"`, Short: "Backup current config", Long: `Backup the current crowdsec configuration including : @@ -232,8 +227,14 @@ func NewConfigBackupCmd() *cobra.Command { Example: `cscli config backup ./my-backup`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigBackup, + RunE: func(_ *cobra.Command, args []string) error { + if err := cli.backup(args[0]); err != nil { + return fmt.Errorf("failed to backup config: %w", err) + } + + return nil + }, } - return cmdConfigBackup + return cmd } diff --git a/cmd/crowdsec-cli/config_feature_flags.go b/cmd/crowdsec-cli/config_feature_flags.go index fbba1f567..d1dbe2b93 100644 --- a/cmd/crowdsec-cli/config_feature_flags.go +++ b/cmd/crowdsec-cli/config_feature_flags.go @@ -11,14 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - showRetired, err := flags.GetBool("retired") - if err != nil { - return err - } - +func (cli *cliConfig) featureFlags(showRetired bool) error { green := color.New(color.FgGreen).SprintFunc() red := color.New(color.FgRed).SprintFunc() yellow := color.New(color.FgYellow).SprintFunc() @@ -121,18 +114,22 @@ func runConfigFeatureFlags(cmd *cobra.Command, args []string) error { return nil } -func NewConfigFeatureFlagsCmd() *cobra.Command { - cmdConfigFeatureFlags := &cobra.Command{ +func (cli *cliConfig) newFeatureFlagsCmd() *cobra.Command { + var showRetired bool + + cmd := &cobra.Command{ Use: "feature-flags", Short: "Displays feature flag status", Long: `Displays the supported feature flags and their current status.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigFeatureFlags, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.featureFlags(showRetired) + }, } - flags := cmdConfigFeatureFlags.Flags() - flags.Bool("retired", false, "Show retired features") + flags := cmd.Flags() + flags.BoolVar(&showRetired, "retired", false, "Show retired features") - return cmdConfigFeatureFlags + return cmd } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index 17d7494c6..513f993ba 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -3,25 +3,17 @@ package main import ( "encoding/json" "fmt" - "io" "os" "path/filepath" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v2" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -type OldAPICfg struct { - MachineID string `json:"machine_id"` - Password string `json:"password"` -} - -func restoreHub(dirPath string) error { +func (cli *cliConfig) restoreHub(dirPath string) error { hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), nil) if err != nil { return err @@ -38,14 +30,14 @@ func restoreHub(dirPath string) error { file, err := os.ReadFile(upstreamListFN) if err != nil { - return fmt.Errorf("error while opening %s : %s", upstreamListFN, err) + return fmt.Errorf("error while opening %s: %w", upstreamListFN, err) } var upstreamList []string err = json.Unmarshal(file, &upstreamList) if err != nil { - return fmt.Errorf("error unmarshaling %s : %s", upstreamListFN, err) + return fmt.Errorf("error unmarshaling %s: %w", upstreamListFN, err) } for _, toinstall := range upstreamList { @@ -55,8 +47,7 @@ func restoreHub(dirPath string) error { continue } - err := item.Install(false, false) - if err != nil { + if err = item.Install(false, false); err != nil { log.Errorf("Error while installing %s : %s", toinstall, err) } } @@ -64,17 +55,17 @@ func restoreHub(dirPath string) error { /*restore the local and tainted items*/ files, err := os.ReadDir(itemDirectory) if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory, err) } for _, file := range files { - //this was the upstream data + // this was the upstream data if file.Name() == fmt.Sprintf("upstream-%s.json", itype) { continue } if itype == cwhub.PARSERS || itype == cwhub.POSTOVERFLOWS { - //we expect a stage here + // we expect a stage here if !file.IsDir() { continue } @@ -84,22 +75,23 @@ func restoreHub(dirPath string) error { log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { - return fmt.Errorf("error while creating stage directory %s : %s", stagedir, err) + return fmt.Errorf("error while creating stage directory %s: %w", stagedir, err) } // find items ifiles, err := os.ReadDir(itemDirectory + "/" + stage + "/") if err != nil { - return fmt.Errorf("failed enumerating files of %s : %s", itemDirectory+"/"+stage, err) + return fmt.Errorf("failed enumerating files of %s: %w", itemDirectory+"/"+stage, err) } - //finally copy item + + // finally copy item for _, tfile := range ifiles { log.Infof("Going to restore local/tainted [%s]", tfile.Name()) sourceFile := fmt.Sprintf("%s/%s/%s", itemDirectory, stage, tfile.Name()) destinationFile := fmt.Sprintf("%s%s", stagedir, tfile.Name()) if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } log.Infof("restored %s to %s", sourceFile, destinationFile) @@ -108,9 +100,11 @@ func restoreHub(dirPath string) error { log.Infof("Going to restore local/tainted [%s]", file.Name()) sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) + if err = CopyFile(sourceFile, destinationFile); err != nil { - return fmt.Errorf("failed copy %s %s to %s : %s", itype, sourceFile, destinationFile, err) + return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) } + log.Infof("restored %s to %s", sourceFile, destinationFile) } } @@ -130,95 +124,60 @@ func restoreHub(dirPath string) error { - Tainted/local/out-of-date scenarios, parsers, postoverflows and collections - Acquisition files (acquis.yaml, acquis.d/*.yaml) */ -func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { +func (cli *cliConfig) restore(dirPath string) error { var err error - if !oldBackup { - backupMain := fmt.Sprintf("%s/config.yaml", dirPath) - if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupMain, csConfig.ConfigPaths.ConfigDir, err) - } + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) + if _, err = os.Stat(backupMain); err == nil { + if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, csConfig.ConfigPaths.ConfigDir, err) } } + } - // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) + // Now we have config.yaml, we should regenerate config struct to have rights paths etc + ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) - log.Debug("Reloading configuration") + log.Debug("Reloading configuration") - csConfig, _, err = loadConfigFor("config") - if err != nil { - return fmt.Errorf("failed to reload configuration: %s", err) + csConfig, _, err = loadConfigFor("config") + if err != nil { + return fmt.Errorf("failed to reload configuration: %w", err) + } + + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupCAPICreds); err == nil { + if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) } + } - backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) - } + backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) + if _, err = os.Stat(backupLAPICreds); err == nil { + if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) } + } - backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) - if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) - } - } - - backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) - if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupProfiles, csConfig.API.Server.ProfilesPath, err) - } - } - } else { - var oldAPICfg OldAPICfg - backupOldAPICfg := fmt.Sprintf("%s/api_creds.json", dirPath) - - jsonFile, err := os.Open(backupOldAPICfg) - if err != nil { - log.Warningf("failed to open %s : %s", backupOldAPICfg, err) - } else { - byteValue, _ := io.ReadAll(jsonFile) - err = json.Unmarshal(byteValue, &oldAPICfg) - if err != nil { - return fmt.Errorf("failed to load json file %s : %s", backupOldAPICfg, err) - } - - apiCfg := csconfig.ApiCredentialsCfg{ - Login: oldAPICfg.MachineID, - Password: oldAPICfg.Password, - URL: CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to dump api credentials: %s", err) - } - apiConfigDumpFile := fmt.Sprintf("%s/online_api_credentials.yaml", csConfig.ConfigPaths.ConfigDir) - if csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - apiConfigDumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } - err = os.WriteFile(apiConfigDumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %s", apiConfigDumpFile, err) - } - log.Infof("Saved API credentials to %s", apiConfigDumpFile) + backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) + if _, err = os.Stat(backupProfiles); err == nil { + if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, csConfig.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ if csConfig.Crowdsec.AcquisitionDirPath != "" { if err = os.MkdirAll(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s : %s", csConfig.Crowdsec.AcquisitionDirPath, err) + return fmt.Errorf("error while creating %s: %w", csConfig.Crowdsec.AcquisitionDirPath, err) } } @@ -228,7 +187,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring backup'ed %s", backupAcquisition) if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) } } @@ -244,7 +203,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { log.Debugf("restoring %s to %s", acquisFile, targetFname) if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } } } @@ -265,37 +224,22 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error { } if err = CopyFile(acquisFile, targetFname); err != nil { - return fmt.Errorf("failed copy %s to %s : %s", acquisFile, targetFname, err) + return fmt.Errorf("failed copy %s to %s: %w", acquisFile, targetFname, err) } log.Infof("Saved acquis %s to %s", acquisFile, targetFname) } } - if err = restoreHub(dirPath); err != nil { - return fmt.Errorf("failed to restore hub config : %s", err) + if err = cli.restoreHub(dirPath); err != nil { + return fmt.Errorf("failed to restore hub config: %w", err) } return nil } -func runConfigRestore(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - oldBackup, err := flags.GetBool("old-backup") - if err != nil { - return err - } - - if err := restoreConfigFromDirectory(args[0], oldBackup); err != nil { - return fmt.Errorf("failed to restore config from %s: %w", args[0], err) - } - - return nil -} - -func NewConfigRestoreCmd() *cobra.Command { - cmdConfigRestore := &cobra.Command{ +func (cli *cliConfig) newRestoreCmd() *cobra.Command { + cmd := &cobra.Command{ Use: `restore "directory"`, Short: `Restore config in backup "directory"`, Long: `Restore the crowdsec configuration from specified backup "directory" including: @@ -308,11 +252,16 @@ func NewConfigRestoreCmd() *cobra.Command { - Backup of API credentials (local API and online API)`, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: runConfigRestore, + RunE: func(_ *cobra.Command, args []string) error { + dirPath := args[0] + + if err := cli.restore(dirPath); err != nil { + return fmt.Errorf("failed to restore config from %s: %w", dirPath, err) + } + + return nil + }, } - flags := cmdConfigRestore.Flags() - flags.BoolP("old-backup", "", false, "To use when you are upgrading crowdsec v0.X to v1.X and you need to restore backup from v0.X") - - return cmdConfigRestore + return cmd } diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index bab911cc3..634ca7741 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -182,31 +182,26 @@ Central API: {{- end }} ` -func runConfigShow(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() +func (cli *cliConfig) show(key string) error { + cfg := cli.cfg() - if err := csConfig.LoadAPIClient(); err != nil { + if err := cfg.LoadAPIClient(); err != nil { log.Errorf("failed to load API client configuration: %s", err) // don't return, we can still show the configuration } - key, err := flags.GetString("key") - if err != nil { - return err - } - if key != "" { return showConfigKey(key) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": // The tests on .Enable look funny because the option has a true default which has // not been set yet (we don't really load the LAPI) and go templates don't dereference // pointers in boolean tests. Prefix notation is the cherry on top. funcs := template.FuncMap{ // can't use generics here - "ValueBool": func(b *bool) bool { return b!=nil && *b }, + "ValueBool": func(b *bool) bool { return b != nil && *b }, } tmp, err := template.New("config").Funcs(funcs).Parse(configShowTemplate) @@ -214,19 +209,19 @@ func runConfigShow(cmd *cobra.Command, args []string) error { return err } - err = tmp.Execute(os.Stdout, csConfig) + err = tmp.Execute(os.Stdout, cfg) if err != nil { return err } case "json": - data, err := json.MarshalIndent(csConfig, "", " ") + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return fmt.Errorf("failed to marshal configuration: %w", err) } fmt.Printf("%s\n", string(data)) case "raw": - data, err := yaml.Marshal(csConfig) + data, err := yaml.Marshal(cfg) if err != nil { return fmt.Errorf("failed to marshal configuration: %w", err) } @@ -237,18 +232,22 @@ func runConfigShow(cmd *cobra.Command, args []string) error { return nil } -func NewConfigShowCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ Use: "show", Short: "Displays current config", Long: `Displays the current cli configuration.`, Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShow, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.show(key) + }, } - flags := cmdConfigShow.Flags() - flags.StringP("key", "", "", "Display only this value (Config.API.Server.ListenURI)") + flags := cmd.Flags() + flags.StringVarP(&key, "key", "", "", "Display only this value (Config.API.Server.ListenURI)") - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/config_showyaml.go b/cmd/crowdsec-cli/config_showyaml.go index 82bc67ffc..52daee6a6 100644 --- a/cmd/crowdsec-cli/config_showyaml.go +++ b/cmd/crowdsec-cli/config_showyaml.go @@ -6,19 +6,21 @@ import ( "github.com/spf13/cobra" ) -func runConfigShowYAML(cmd *cobra.Command, args []string) error { +func (cli *cliConfig) showYAML() error { fmt.Println(mergedConfig) return nil } -func NewConfigShowYAMLCmd() *cobra.Command { - cmdConfigShow := &cobra.Command{ +func (cli *cliConfig) newShowYAMLCmd() *cobra.Command { + cmd := &cobra.Command{ Use: "show-yaml", Short: "Displays merged config.yaml + config.yaml.local", Args: cobra.ExactArgs(0), DisableAutoGenTag: true, - RunE: runConfigShowYAML, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.showYAML() + }, } - return cmdConfigShow + return cmd } diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 27ac17d55..1f87390b6 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -231,7 +231,7 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) cmd.AddCommand(NewCLIVersion().NewCommand()) - cmd.AddCommand(NewConfigCmd()) + cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIHub(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMetrics(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) From 8da490f5930406180bef6f4b0b99e0b0dc86dff8 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 22 Feb 2024 11:42:33 +0100 Subject: [PATCH 3/6] refact pkg/apiclient (#2846) * extract resperr.go * extract method prepareRequest() * reset token inside mutex --- pkg/apiclient/auth_jwt.go | 39 +++++++++++++++++++++++---------- pkg/apiclient/client.go | 36 ------------------------------ pkg/apiclient/resperr.go | 46 +++++++++++++++++++++++++++++++++++++++ pkg/apiserver/apic.go | 1 - 4 files changed, 73 insertions(+), 49 deletions(-) create mode 100644 pkg/apiclient/resperr.go diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index 71b0e2731..2ead10cf6 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -130,20 +130,24 @@ func (t *JWTTransport) refreshJwtToken() error { return nil } -// RoundTrip implements the RoundTripper interface. -func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI - // we use a mutex to avoid this - // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) - t.refreshTokenMutex.Lock() - if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { - if err := t.refreshJwtToken(); err != nil { - t.refreshTokenMutex.Unlock() +func (t *JWTTransport) needsTokenRefresh() bool { + return t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) +} +// prepareRequest returns a copy of the request with the necessary authentication headers. +func (t *JWTTransport) prepareRequest(req *http.Request) (*http.Request, error) { + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless + // and will cause overload on CAPI. We use a mutex to avoid this. + t.refreshTokenMutex.Lock() + defer t.refreshTokenMutex.Unlock() + + // We bypass the refresh if we are requesting the login endpoint, as it does not require a token, + // and it leads to do 2 requests instead of one (refresh + actual login request). + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && t.needsTokenRefresh() { + if err := t.refreshJwtToken(); err != nil { return nil, err } } - t.refreshTokenMutex.Unlock() if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) @@ -151,6 +155,16 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + return req, nil +} + +// RoundTrip implements the RoundTripper interface. +func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req, err := t.prepareRequest(req) + if err != nil { + return nil, err + } + if log.GetLevel() >= log.TraceLevel { //requestToDump := cloneRequest(req) dump, _ := httputil.DumpRequest(req, true) @@ -166,7 +180,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { // we had an error (network error for example, or 401 because token is refused), reset the token? - t.Token = "" + t.ResetToken() return resp, fmt.Errorf("performing jwt auth: %w", err) } @@ -189,7 +203,8 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } -// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded. +// transport() returns a round tripper that retries once when the status is unauthorized, +// and 5 times when the infrastructure is overloaded. func (t *JWTTransport) transport() http.RoundTripper { transport := t.Transport if transport == nil { diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index b183a8c79..b487f68a6 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -4,9 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" - "io" "net/http" "net/url" @@ -167,44 +165,10 @@ type Response struct { //... } -type ErrorResponse struct { - models.ErrorResponse -} - -func (e *ErrorResponse) Error() string { - err := fmt.Sprintf("API error: %s", *e.Message) - if len(e.Errors) > 0 { - err += fmt.Sprintf(" (%s)", e.Errors) - } - - return err -} - func newResponse(r *http.Response) *Response { return &Response{Response: r} } -func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { - return nil - } - - errorResponse := &ErrorResponse{} - - data, err := io.ReadAll(r.Body) - if err == nil && len(data)>0 { - err := json.Unmarshal(data, errorResponse) - if err != nil { - return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) - } - } else { - errorResponse.Message = new(string) - *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) - } - - return errorResponse -} - type ListOpts struct { //Page int //PerPage int diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go new file mode 100644 index 000000000..ff954a736 --- /dev/null +++ b/pkg/apiclient/resperr.go @@ -0,0 +1,46 @@ +package apiclient + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +type ErrorResponse struct { + models.ErrorResponse +} + +func (e *ErrorResponse) Error() string { + err := fmt.Sprintf("API error: %s", *e.Message) + if len(e.Errors) > 0 { + err += fmt.Sprintf(" (%s)", e.Errors) + } + + return err +} + +// CheckResponse verifies the API response and builds an appropriate Go error if necessary. +func CheckResponse(r *http.Response) error { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { + return nil + } + + ret := &ErrorResponse{} + + data, err := io.ReadAll(r.Body) + if err != nil || len(data) == 0 { + ret.Message = ptr.Of(fmt.Sprintf("http code %d, no error message", r.StatusCode)) + return ret + } + + if err := json.Unmarshal(data, ret); err != nil { + return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) + } + + return ret +} diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 2fdb01144..f57ae685e 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -539,7 +539,6 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { scenario = *decision.Scenario scope = types.ListOrigin default: - // XXX: this or nil? scenario = "" scope = "" From 0df8f54fbbd08ab857e153229a43cf9e3c3f258e Mon Sep 17 00:00:00 2001 From: Laurence Jones Date: Thu, 22 Feb 2024 11:18:29 +0000 Subject: [PATCH 4/6] Add unix socket option to http plugin, we have to use this in conjunction with URL parameter as we dont know which path the user wants so if they would like to communicate over unix socket they need to use both, however, the hostname can be whatever they want. We could be a little smarter and actually parse the url, however, increasing code when a user can just define it correctly make no sense (#2764) --- cmd/notification-http/main.go | 42 +++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 340d462c1..382f30fea 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -7,8 +7,10 @@ import ( "crypto/x509" "fmt" "io" + "net" "net/http" "os" + "strings" "github.com/crowdsecurity/crowdsec/pkg/protobufs" "github.com/hashicorp/go-hclog" @@ -19,6 +21,7 @@ import ( type PluginConfig struct { Name string `yaml:"name"` URL string `yaml:"url"` + UnixSocket string `yaml:"unix_socket"` Headers map[string]string `yaml:"headers"` SkipTLSVerification bool `yaml:"skip_tls_verification"` Method string `yaml:"method"` @@ -66,36 +69,40 @@ func getCertPool(caPath string) (*x509.CertPool, error) { return cp, nil } -func getTLSClient(tlsVerify bool, caPath, certPath, keyPath string) (*http.Client, error) { - var client *http.Client - - caCertPool, err := getCertPool(caPath) +func getTLSClient(c *PluginConfig) error { + caCertPool, err := getCertPool(c.CAPath) if err != nil { - return nil, err + return err } tlsConfig := &tls.Config{ RootCAs: caCertPool, - InsecureSkipVerify: tlsVerify, + InsecureSkipVerify: c.SkipTLSVerification, } - if certPath != "" && keyPath != "" { - logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", certPath, keyPath)) + if c.CertPath != "" && c.KeyPath != "" { + logger.Info(fmt.Sprintf("Using client certificate '%s' and key '%s'", c.CertPath, c.KeyPath)) - cert, err := tls.LoadX509KeyPair(certPath, keyPath) + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) if err != nil { - return nil, fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", certPath, keyPath, err) + return fmt.Errorf("unable to load client certificate '%s' and key '%s': %w", c.CertPath, c.KeyPath, err) } tlsConfig.Certificates = []tls.Certificate{cert} } - - client = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, + transport := &http.Transport{ + TLSClientConfig: tlsConfig, } - return client, err + if c.UnixSocket != "" { + logger.Info(fmt.Sprintf("Using socket '%s'", c.UnixSocket)) + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", strings.TrimSuffix(c.UnixSocket, "/")) + } + } + c.Client = &http.Client{ + Transport: transport, + } + return nil } func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { @@ -135,6 +142,7 @@ func (s *HTTPPlugin) Notify(ctx context.Context, notification *protobufs.Notific if resp.StatusCode < 200 || resp.StatusCode >= 300 { logger.Warn(fmt.Sprintf("HTTP server returned non 200 status code: %d", resp.StatusCode)) + logger.Debug(fmt.Sprintf("HTTP server returned body: %s", string(respData))) return &protobufs.Empty{}, nil } @@ -147,7 +155,7 @@ func (s *HTTPPlugin) Configure(ctx context.Context, config *protobufs.Config) (* if err != nil { return nil, err } - d.Client, err = getTLSClient(d.SkipTLSVerification, d.CAPath, d.CertPath, d.KeyPath) + err = getTLSClient(&d) if err != nil { return nil, err } From e34af358d7b96df49634f28696e9c1b1f01e097c Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 23 Feb 2024 10:37:04 +0100 Subject: [PATCH 5/6] refact cscli (globals) (#2854) * cscli capi: avoid globals, extract methods * cscli config restore: avoid global * cscli hubtest: avoid global * lint (whitespace, wrapped errors) --- cmd/crowdsec-cli/bouncers.go | 24 +-- cmd/crowdsec-cli/capi.go | 255 ++++++++++++++++------------- cmd/crowdsec-cli/config_restore.go | 54 +++--- cmd/crowdsec-cli/hubtest.go | 172 +++++++++++-------- cmd/crowdsec-cli/main.go | 4 +- 5 files changed, 287 insertions(+), 222 deletions(-) diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go index 717e9aef5..35f4320c5 100644 --- a/cmd/crowdsec-cli/bouncers.go +++ b/cmd/crowdsec-cli/bouncers.go @@ -3,6 +3,7 @@ package main import ( "encoding/csv" "encoding/json" + "errors" "fmt" "os" "slices" @@ -58,13 +59,16 @@ Note: This command requires database direct access, so is intended to be run on DisableAutoGenTag: true, PersistentPreRunE: func(_ *cobra.Command, _ []string) error { var err error - if err = require.LAPI(cli.cfg()); err != nil { + + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { return err } - cli.db, err = database.NewClient(cli.cfg().DbConfig) + cli.db, err = database.NewClient(cfg.DbConfig) if err != nil { - return fmt.Errorf("can't connect to the database: %s", err) + return fmt.Errorf("can't connect to the database: %w", err) } return nil @@ -84,7 +88,7 @@ func (cli *cliBouncers) list() error { bouncers, err := cli.db.ListBouncers() if err != nil { - return fmt.Errorf("unable to list bouncers: %s", err) + return fmt.Errorf("unable to list bouncers: %w", err) } switch cli.cfg().Cscli.Output { @@ -146,13 +150,13 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { if key == "" { key, err = middlewares.GenerateAPIKey(keyLength) if err != nil { - return fmt.Errorf("unable to generate api key: %s", err) + return fmt.Errorf("unable to generate api key: %w", err) } } _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) if err != nil { - return fmt.Errorf("unable to create bouncer: %s", err) + return fmt.Errorf("unable to create bouncer: %w", err) } switch cli.cfg().Cscli.Output { @@ -165,7 +169,7 @@ func (cli *cliBouncers) add(bouncerName string, key string) error { case "json": j, err := json.Marshal(key) if err != nil { - return fmt.Errorf("unable to marshal api key") + return errors.New("unable to marshal api key") } fmt.Print(string(j)) @@ -191,7 +195,7 @@ cscli bouncers add MyBouncerName --key `, flags := cmd.Flags() flags.StringP("length", "l", "", "length of the api key") - flags.MarkDeprecated("length", "use --key instead") + _ = flags.MarkDeprecated("length", "use --key instead") flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") return cmd @@ -218,7 +222,7 @@ func (cli *cliBouncers) delete(bouncers []string) error { for _, bouncerID := range bouncers { err := cli.db.DeleteBouncer(bouncerID) if err != nil { - return fmt.Errorf("unable to delete bouncer '%s': %s", bouncerID, err) + return fmt.Errorf("unable to delete bouncer '%s': %w", bouncerID, err) } log.Infof("bouncer '%s' deleted successfully", bouncerID) @@ -280,7 +284,7 @@ func (cli *cliBouncers) prune(duration time.Duration, force bool) error { deleted, err := cli.db.BulkDeleteBouncers(bouncers) if err != nil { - return fmt.Errorf("unable to prune bouncers: %s", err) + return fmt.Errorf("unable to prune bouncers: %w", err) } fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go index 358d91ee2..e56a8a747 100644 --- a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/capi.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "net/url" "os" @@ -26,24 +27,29 @@ const ( CAPIURLPrefix = "v3" ) -type cliCapi struct{} - -func NewCLICapi() *cliCapi { - return &cliCapi{} +type cliCapi struct { + cfg configGetter } -func (cli cliCapi) NewCommand() *cobra.Command { - var cmd = &cobra.Command{ +func NewCLICapi(cfg configGetter) *cliCapi { + return &cliCapi{ + cfg: cfg, + } +} + +func (cli *cliCapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ Use: "capi [action]", Short: "Manage interaction with Central API (CAPI)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - if err := require.LAPI(csConfig); err != nil { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { return err } - if err := require.CAPI(csConfig); err != nil { + if err := require.CAPI(cfg); err != nil { return err } @@ -51,78 +57,92 @@ func (cli cliCapi) NewCommand() *cobra.Command { }, } - cmd.AddCommand(cli.NewRegisterCmd()) - cmd.AddCommand(cli.NewStatusCmd()) + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) return cmd } -func (cli cliCapi) NewRegisterCmd() *cobra.Command { +func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error { + cfg := cli.cfg() + + capiUser, err := generateID(capiUserPrefix) + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + + password := strfmt.Password(generatePassword(passwordLength)) + + apiurl, err := url.Parse(types.CAPIBaseURL) + if err != nil { + return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) + } + + _, err = apiclient.RegisterClient(&apiclient.Config{ + MachineID: capiUser, + Password: password, + UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), + URL: apiurl, + VersionPrefix: CAPIURLPrefix, + }, nil) + + if err != nil { + return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) + } + + log.Infof("Successfully registered to Central API (CAPI)") + + var dumpFile string + + switch { + case outputFile != "": + dumpFile = outputFile + case cfg.API.Server.OnlineClient.CredentialsFilePath != "": + dumpFile = cfg.API.Server.OnlineClient.CredentialsFilePath + default: + dumpFile = "" + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: capiUser, + Password: password.String(), + URL: types.CAPIBaseURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to marshal api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + log.Infof("Central API credentials written to '%s'", dumpFile) + } else { + fmt.Println(string(apiConfigDump)) + } + + log.Warning(ReloadMessage()) + + return nil +} + +func (cli *cliCapi) newRegisterCmd() *cobra.Command { var ( capiUserPrefix string - outputFile string + outputFile string ) - var cmd = &cobra.Command{ + cmd := &cobra.Command{ Use: "register", Short: "Register to Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - var err error - capiUser, err := generateID(capiUserPrefix) - if err != nil { - return fmt.Errorf("unable to generate machine id: %s", err) - } - password := strfmt.Password(generatePassword(passwordLength)) - apiurl, err := url.Parse(types.CAPIBaseURL) - if err != nil { - return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) - } - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: capiUser, - Password: password, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiurl, - VersionPrefix: CAPIURLPrefix, - }, nil) - - if err != nil { - return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) - } - log.Printf("Successfully registered to Central API (CAPI)") - - var dumpFile string - - if outputFile != "" { - dumpFile = outputFile - } else if csConfig.API.Server.OnlineClient.CredentialsFilePath != "" { - dumpFile = csConfig.API.Server.OnlineClient.CredentialsFilePath - } else { - dumpFile = "" - } - apiCfg := csconfig.ApiCredentialsCfg{ - Login: capiUser, - Password: password.String(), - URL: types.CAPIBaseURL, - } - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) - } - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) - } - log.Printf("Central API credentials written to '%s'", dumpFile) - } else { - fmt.Println(string(apiConfigDump)) - } - - log.Warning(ReloadMessage()) - - return nil + return cli.register(capiUserPrefix, outputFile) }, } @@ -136,59 +156,66 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command { return cmd } -func (cli cliCapi) NewStatusCmd() *cobra.Command { +func (cli *cliCapi) status() error { + cfg := cli.cfg() + + if err := require.CAPIRegistered(cfg); err != nil { + return err + } + + password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) + + apiurl, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL) + if err != nil { + return fmt.Errorf("parsing api url ('%s'): %w", cfg.API.Server.OnlineClient.Credentials.URL, err) + } + + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + + scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) + if err != nil { + return fmt.Errorf("failed to get scenarios: %w", err) + } + + if len(scenarios) == 0 { + return errors.New("no scenarios installed, abort") + } + + Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) + if err != nil { + return fmt.Errorf("init default client: %w", err) + } + + t := models.WatcherAuthRequest{ + MachineID: &cfg.API.Server.OnlineClient.Credentials.Login, + Password: &password, + Scenarios: scenarios, + } + + log.Infof("Loaded credentials from %s", cfg.API.Server.OnlineClient.CredentialsFilePath) + log.Infof("Trying to authenticate with username %s on %s", cfg.API.Server.OnlineClient.Credentials.Login, apiurl) + + _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + if err != nil { + return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) + } + + log.Info("You can successfully interact with Central API (CAPI)") + + return nil +} + +func (cli *cliCapi) newStatusCmd() *cobra.Command { cmd := &cobra.Command{ Use: "status", Short: "Check status with the Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - if err := require.CAPIRegistered(csConfig); err != nil { - return err - } - - password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) - - apiurl, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL) - if err != nil { - return fmt.Errorf("parsing api url ('%s'): %w", csConfig.API.Server.OnlineClient.Credentials.URL, err) - } - - hub, err := require.Hub(csConfig, nil, nil) - if err != nil { - return err - } - - scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get scenarios: %w", err) - } - - if len(scenarios) == 0 { - return fmt.Errorf("no scenarios installed, abort") - } - - Client, err = apiclient.NewDefaultClient(apiurl, CAPIURLPrefix, fmt.Sprintf("crowdsec/%s", version.String()), nil) - if err != nil { - return fmt.Errorf("init default client: %w", err) - } - - t := models.WatcherAuthRequest{ - MachineID: &csConfig.API.Server.OnlineClient.Credentials.Login, - Password: &password, - Scenarios: scenarios, - } - - log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl) - - _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) - } - log.Infof("You can successfully interact with Central API (CAPI)") - - return nil + return cli.status() }, } diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index 513f993ba..ee7179b73 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -14,7 +14,9 @@ import ( ) func (cli *cliConfig) restoreHub(dirPath string) error { - hub, err := require.Hub(csConfig, require.RemoteHub(csConfig), nil) + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(cfg), nil) if err != nil { return err } @@ -71,7 +73,7 @@ func (cli *cliConfig) restoreHub(dirPath string) error { } stage := file.Name() - stagedir := fmt.Sprintf("%s/%s/%s/", csConfig.ConfigPaths.ConfigDir, itype, stage) + stagedir := fmt.Sprintf("%s/%s/%s/", cfg.ConfigPaths.ConfigDir, itype, stage) log.Debugf("Found stage %s in %s, target directory : %s", stage, itype, stagedir) if err = os.MkdirAll(stagedir, os.ModePerm); err != nil { @@ -99,7 +101,7 @@ func (cli *cliConfig) restoreHub(dirPath string) error { } else { log.Infof("Going to restore local/tainted [%s]", file.Name()) sourceFile := fmt.Sprintf("%s/%s", itemDirectory, file.Name()) - destinationFile := fmt.Sprintf("%s/%s/%s", csConfig.ConfigPaths.ConfigDir, itype, file.Name()) + destinationFile := fmt.Sprintf("%s/%s/%s", cfg.ConfigPaths.ConfigDir, itype, file.Name()) if err = CopyFile(sourceFile, destinationFile); err != nil { return fmt.Errorf("failed copy %s %s to %s: %w", itype, sourceFile, destinationFile, err) @@ -127,17 +129,19 @@ func (cli *cliConfig) restoreHub(dirPath string) error { func (cli *cliConfig) restore(dirPath string) error { var err error + cfg := cli.cfg() + backupMain := fmt.Sprintf("%s/config.yaml", dirPath) if _, err = os.Stat(backupMain); err == nil { - if csConfig.ConfigPaths != nil && csConfig.ConfigPaths.ConfigDir != "" { - if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupMain, csConfig.ConfigPaths.ConfigDir, err) + if cfg.ConfigPaths != nil && cfg.ConfigPaths.ConfigDir != "" { + if err = CopyFile(backupMain, fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir)); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupMain, cfg.ConfigPaths.ConfigDir, err) } } } // Now we have config.yaml, we should regenerate config struct to have rights paths etc - ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir) + ConfigFilePath = fmt.Sprintf("%s/config.yaml", cfg.ConfigPaths.ConfigDir) log.Debug("Reloading configuration") @@ -146,38 +150,40 @@ func (cli *cliConfig) restore(dirPath string) error { return fmt.Errorf("failed to reload configuration: %w", err) } + cfg = cli.cfg() + backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath) if _, err = os.Stat(backupCAPICreds); err == nil { - if err = CopyFile(backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, csConfig.API.Server.OnlineClient.CredentialsFilePath, err) + if err = CopyFile(backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupCAPICreds, cfg.API.Server.OnlineClient.CredentialsFilePath, err) } } backupLAPICreds := fmt.Sprintf("%s/local_api_credentials.yaml", dirPath) if _, err = os.Stat(backupLAPICreds); err == nil { - if err = CopyFile(backupLAPICreds, csConfig.API.Client.CredentialsFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, csConfig.API.Client.CredentialsFilePath, err) + if err = CopyFile(backupLAPICreds, cfg.API.Client.CredentialsFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupLAPICreds, cfg.API.Client.CredentialsFilePath, err) } } backupProfiles := fmt.Sprintf("%s/profiles.yaml", dirPath) if _, err = os.Stat(backupProfiles); err == nil { - if err = CopyFile(backupProfiles, csConfig.API.Server.ProfilesPath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, csConfig.API.Server.ProfilesPath, err) + if err = CopyFile(backupProfiles, cfg.API.Server.ProfilesPath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupProfiles, cfg.API.Server.ProfilesPath, err) } } backupSimulation := fmt.Sprintf("%s/simulation.yaml", dirPath) if _, err = os.Stat(backupSimulation); err == nil { - if err = CopyFile(backupSimulation, csConfig.ConfigPaths.SimulationFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, csConfig.ConfigPaths.SimulationFilePath, err) + if err = CopyFile(backupSimulation, cfg.ConfigPaths.SimulationFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupSimulation, cfg.ConfigPaths.SimulationFilePath, err) } } /*if there is a acquisition dir, restore its content*/ - if csConfig.Crowdsec.AcquisitionDirPath != "" { - if err = os.MkdirAll(csConfig.Crowdsec.AcquisitionDirPath, 0o700); err != nil { - return fmt.Errorf("error while creating %s: %w", csConfig.Crowdsec.AcquisitionDirPath, err) + if cfg.Crowdsec.AcquisitionDirPath != "" { + if err = os.MkdirAll(cfg.Crowdsec.AcquisitionDirPath, 0o700); err != nil { + return fmt.Errorf("error while creating %s: %w", cfg.Crowdsec.AcquisitionDirPath, err) } } @@ -186,8 +192,8 @@ func (cli *cliConfig) restore(dirPath string) error { if _, err = os.Stat(backupAcquisition); err == nil { log.Debugf("restoring backup'ed %s", backupAcquisition) - if err = CopyFile(backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath); err != nil { - return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, csConfig.Crowdsec.AcquisitionFilePath, err) + if err = CopyFile(backupAcquisition, cfg.Crowdsec.AcquisitionFilePath); err != nil { + return fmt.Errorf("failed copy %s to %s: %w", backupAcquisition, cfg.Crowdsec.AcquisitionFilePath, err) } } @@ -195,7 +201,7 @@ func (cli *cliConfig) restore(dirPath string) error { acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml") if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil { for _, acquisFile := range acquisFiles { - targetFname, err := filepath.Abs(csConfig.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) + targetFname, err := filepath.Abs(cfg.Crowdsec.AcquisitionDirPath + "/" + filepath.Base(acquisFile)) if err != nil { return fmt.Errorf("while saving %s to %s: %w", acquisFile, targetFname, err) } @@ -208,12 +214,12 @@ func (cli *cliConfig) restore(dirPath string) error { } } - if csConfig.Crowdsec != nil && len(csConfig.Crowdsec.AcquisitionFiles) > 0 { - for _, acquisFile := range csConfig.Crowdsec.AcquisitionFiles { + if cfg.Crowdsec != nil && len(cfg.Crowdsec.AcquisitionFiles) > 0 { + for _, acquisFile := range cfg.Crowdsec.AcquisitionFiles { log.Infof("backup filepath from dir -> %s", acquisFile) // if it was the default one, it has already been backed up - if csConfig.Crowdsec.AcquisitionFilePath == acquisFile { + if cfg.Crowdsec.AcquisitionFilePath == acquisFile { log.Infof("skip this one") continue } diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go index 1860540e7..8f5ab0873 100644 --- a/cmd/crowdsec-cli/hubtest.go +++ b/cmd/crowdsec-cli/hubtest.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "math" "os" @@ -20,21 +21,29 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) -var HubTest hubtest.HubTest -var HubAppsecTests hubtest.HubTest -var hubPtr *hubtest.HubTest -var isAppsecTest bool +var ( + HubTest hubtest.HubTest + HubAppsecTests hubtest.HubTest + hubPtr *hubtest.HubTest + isAppsecTest bool +) -type cliHubTest struct{} - -func NewCLIHubTest() *cliHubTest { - return &cliHubTest{} +type cliHubTest struct { + cfg configGetter } -func (cli cliHubTest) NewCommand() *cobra.Command { - var hubPath string - var crowdsecPath string - var cscliPath string +func NewCLIHubTest(cfg configGetter) *cliHubTest { + return &cliHubTest{ + cfg: cfg, + } +} + +func (cli *cliHubTest) NewCommand() *cobra.Command { + var ( + hubPath string + crowdsecPath string + cscliPath string + ) cmd := &cobra.Command{ Use: "hubtest", @@ -53,11 +62,13 @@ func (cli cliHubTest) NewCommand() *cobra.Command { if err != nil { return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) } - /*commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests*/ + + // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests hubPtr = &HubTest if isAppsecTest { hubPtr = &HubAppsecTests } + return nil }, } @@ -79,13 +90,16 @@ func (cli cliHubTest) NewCommand() *cobra.Command { return cmd } -func (cli cliHubTest) NewCreateCmd() *cobra.Command { +func (cli *cliHubTest) NewCreateCmd() *cobra.Command { + var ( + ignoreParsers bool + labels map[string]string + logType string + ) + parsers := []string{} postoverflows := []string{} scenarios := []string{} - var ignoreParsers bool - var labels map[string]string - var logType string cmd := &cobra.Command{ Use: "create", @@ -107,7 +121,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios } if logType == "" { - return fmt.Errorf("please provide a type (--type) for the test") + return errors.New("please provide a type (--type) for the test") } if err := os.MkdirAll(testPath, os.ModePerm); err != nil { @@ -118,7 +132,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios configFileData := &hubtest.HubTestItemConfig{} if logType == "appsec" { - //create empty nuclei template file + // create empty nuclei template file nucleiFileName := fmt.Sprintf("%s.yaml", testName) nucleiFilePath := filepath.Join(testPath, nucleiFileName) nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0755) @@ -128,7 +142,7 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) if ntpl == nil { - return fmt.Errorf("unable to parse nuclei template") + return errors.New("unable to parse nuclei template") } ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) nucleiFile.Close() @@ -188,24 +202,24 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) - } fd, err := os.Create(configFilePath) if err != nil { - return fmt.Errorf("open: %s", err) + return fmt.Errorf("open: %w", err) } data, err := yaml.Marshal(configFileData) if err != nil { - return fmt.Errorf("marshal: %s", err) + return fmt.Errorf("marshal: %w", err) } _, err = fd.Write(data) if err != nil { - return fmt.Errorf("write: %s", err) + return fmt.Errorf("write: %w", err) } if err := fd.Close(); err != nil { - return fmt.Errorf("close: %s", err) + return fmt.Errorf("close: %w", err) } + return nil }, } @@ -219,20 +233,25 @@ cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios return cmd } -func (cli cliHubTest) NewRunCmd() *cobra.Command { - var noClean bool - var runAll bool - var forceClean bool - var NucleiTargetHost string - var AppSecHost string - var cmd = &cobra.Command{ +func (cli *cliHubTest) NewRunCmd() *cobra.Command { + var ( + noClean bool + runAll bool + forceClean bool + NucleiTargetHost string + AppSecHost string + ) + + cmd := &cobra.Command{ Use: "run", Short: "run [test_name]", DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, args []string) error { + cfg := cli.cfg() + if !runAll && len(args) == 0 { printHelp(cmd) - return fmt.Errorf("please provide test to run or --all flag") + return errors.New("please provide test to run or --all flag") } hubPtr.NucleiTargetHost = NucleiTargetHost hubPtr.AppSecHost = AppSecHost @@ -244,7 +263,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { for _, testName := range args { _, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } } } @@ -252,7 +271,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { // set timezone to avoid DST issues os.Setenv("TZ", "UTC") for _, test := range hubPtr.Tests { - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { log.Infof("Running test '%s'", test.Name) } err := test.Run() @@ -264,6 +283,8 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { return nil }, PersistentPostRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + success := true testResult := make(map[string]bool) for _, test := range hubPtr.Tests { @@ -280,7 +301,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } if !noClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } fmt.Printf("\nPlease fill your assert file(s) for test '%s', exiting\n", test.Name) @@ -288,18 +309,18 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } testResult[test.Name] = test.Success if test.Success { - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) } if !noClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } } else { success = false cleanTestEnv := false - if csConfig.Cscli.Output == "human" { + if cfg.Cscli.Output == "human" { if len(test.ParserAssert.Fails) > 0 { fmt.Println() log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) @@ -330,20 +351,20 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { Default: true, } if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { - return fmt.Errorf("unable to ask to remove runtime folder: %s", err) + return fmt.Errorf("unable to ask to remove runtime folder: %w", err) } } } if cleanTestEnv || forceClean { if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } } } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": hubTestResultTable(color.Output, testResult) case "json": @@ -359,11 +380,11 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { } jsonStr, err := json.Marshal(jsonResult) if err != nil { - return fmt.Errorf("unable to json test result: %s", err) + return fmt.Errorf("unable to json test result: %w", err) } fmt.Println(string(jsonStr)) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } if !success { @@ -383,7 +404,7 @@ func (cli cliHubTest) NewRunCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewCleanCmd() *cobra.Command { +func (cli *cliHubTest) NewCleanCmd() *cobra.Command { var cmd = &cobra.Command{ Use: "clean", Short: "clean [test_name]", @@ -393,10 +414,10 @@ func (cli cliHubTest) NewCleanCmd() *cobra.Command { for _, testName := range args { test, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %s", test.Name, err) + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) } } @@ -407,7 +428,7 @@ func (cli cliHubTest) NewCleanCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewInfoCmd() *cobra.Command { +func (cli *cliHubTest) NewInfoCmd() *cobra.Command { cmd := &cobra.Command{ Use: "info", Short: "info [test_name]", @@ -417,7 +438,7 @@ func (cli cliHubTest) NewInfoCmd() *cobra.Command { for _, testName := range args { test, err := hubPtr.LoadTestItem(testName) if err != nil { - return fmt.Errorf("unable to load test '%s': %s", testName, err) + return fmt.Errorf("unable to load test '%s': %w", testName, err) } fmt.Println() fmt.Printf(" Test name : %s\n", test.Name) @@ -440,17 +461,19 @@ func (cli cliHubTest) NewInfoCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewListCmd() *cobra.Command { +func (cli *cliHubTest) NewListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "list", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %s", err) + return fmt.Errorf("unable to load all tests: %w", err) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": hubTestListTable(color.Output, hubPtr.Tests) case "json": @@ -460,7 +483,7 @@ func (cli cliHubTest) NewListCmd() *cobra.Command { } fmt.Println(string(j)) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } return nil @@ -470,18 +493,22 @@ func (cli cliHubTest) NewListCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewCoverageCmd() *cobra.Command { - var showParserCov bool - var showScenarioCov bool - var showOnlyPercent bool - var showAppsecCov bool +func (cli *cliHubTest) NewCoverageCmd() *cobra.Command { + var ( + showParserCov bool + showScenarioCov bool + showOnlyPercent bool + showAppsecCov bool + ) cmd := &cobra.Command{ Use: "coverage", Short: "coverage", DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - //for this one we explicitly don't do for appsec + cfg := cli.cfg() + + // for this one we explicitly don't do for appsec if err := HubTest.LoadAllTests(); err != nil { return fmt.Errorf("unable to load all tests: %+v", err) } @@ -499,7 +526,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showParserCov || showAll { parserCoverage, err = HubTest.GetParsersCoverage() if err != nil { - return fmt.Errorf("while getting parser coverage: %s", err) + return fmt.Errorf("while getting parser coverage: %w", err) } parserTested := 0 for _, test := range parserCoverage { @@ -513,7 +540,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showScenarioCov || showAll { scenarioCoverage, err = HubTest.GetScenariosCoverage() if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) + return fmt.Errorf("while getting scenario coverage: %w", err) } scenarioTested := 0 @@ -529,7 +556,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { if showAppsecCov || showAll { appsecRuleCoverage, err = HubTest.GetAppsecCoverage() if err != nil { - return fmt.Errorf("while getting scenario coverage: %s", err) + return fmt.Errorf("while getting scenario coverage: %w", err) } appsecRuleTested := 0 @@ -542,19 +569,20 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { } if showOnlyPercent { - if showAll { + switch { + case showAll: fmt.Printf("parsers=%d%%\nscenarios=%d%%\nappsec_rules=%d%%", parserCoveragePercent, scenarioCoveragePercent, appsecRuleCoveragePercent) - } else if showParserCov { + case showParserCov: fmt.Printf("parsers=%d%%", parserCoveragePercent) - } else if showScenarioCov { + case showScenarioCov: fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) - } else if showAppsecCov { + case showAppsecCov: fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) } os.Exit(0) } - switch csConfig.Cscli.Output { + switch cfg.Cscli.Output { case "human": if showParserCov || showAll { hubTestParserCoverageTable(color.Output, parserCoverage) @@ -595,7 +623,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { } fmt.Printf("%s", dump) default: - return fmt.Errorf("only human/json output modes are supported") + return errors.New("only human/json output modes are supported") } return nil @@ -610,7 +638,7 @@ func (cli cliHubTest) NewCoverageCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewEvalCmd() *cobra.Command { +func (cli *cliHubTest) NewEvalCmd() *cobra.Command { var evalExpression string cmd := &cobra.Command{ @@ -647,7 +675,7 @@ func (cli cliHubTest) NewEvalCmd() *cobra.Command { return cmd } -func (cli cliHubTest) NewExplainCmd() *cobra.Command { +func (cli *cliHubTest) NewExplainCmd() *cobra.Command { cmd := &cobra.Command{ Use: "explain", Short: "explain [test_name]", @@ -666,7 +694,7 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command { } if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { - return fmt.Errorf("unable to load parser result after run: %s", err) + return fmt.Errorf("unable to load parser result after run: %w", err) } } @@ -677,7 +705,7 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command { } if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { - return fmt.Errorf("unable to load scenario result after run: %s", err) + return fmt.Errorf("unable to load scenario result after run: %w", err) } } opts := dumps.DumpOpts{} diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 1f87390b6..446901e4a 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -240,12 +240,12 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLICapi().NewCommand()) + cmd.AddCommand(NewCLICapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) cmd.AddCommand(NewCompletionCmd()) cmd.AddCommand(NewCLIConsole(cli.cfg).NewCommand()) cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIHubTest().NewCommand()) + cmd.AddCommand(NewCLIHubTest(cli.cfg).NewCommand()) cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand()) cmd.AddCommand(NewCLISupport().NewCommand()) cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand()) From 4bf640c6e86185b506fde7332a338ccf2eb711ca Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:03:50 +0100 Subject: [PATCH 6/6] refact pkg/apiserver (auth helpers) (#2856) --- pkg/apiserver/controllers/v1/alerts.go | 5 +--- pkg/apiserver/controllers/v1/heartbeat.go | 5 +--- pkg/apiserver/controllers/v1/metrics.go | 34 ++++++++++------------- pkg/apiserver/controllers/v1/utils.go | 32 +++++++++++++++++---- pkg/apiserver/middlewares/v1/api_key.go | 11 ++------ pkg/apiserver/middlewares/v1/jwt.go | 8 +++--- 6 files changed, 50 insertions(+), 45 deletions(-) diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index e7d106d72..ad183e4ba 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -9,7 +9,6 @@ import ( "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/google/uuid" @@ -143,9 +142,7 @@ func normalizeScope(scope string) string { func (c *Controller) CreateAlert(gctx *gin.Context) { var input models.AddAlertsRequest - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index b19b450f0..e1231eaa9 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -3,14 +3,11 @@ package v1 import ( "net/http" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 13ccf9ac9..ddb38512a 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -3,7 +3,6 @@ package v1 import ( "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) @@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec( []string{"endpoint", "method"}) func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusMachinesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - claims := jwt.ExtractClaims(c) - if claims != nil { - if rawID, ok := claims["id"]; ok { - machineID := rawID.(string) - LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() - } + machineID, _ := getMachineIDFromContext(c) + if machineID != "" { + LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": c.Request.URL.Path, + "method": c.Request.Method}).Inc() } c.Next() @@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { func PrometheusBouncersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiBouncerHits.With(prometheus.Labels{ - "bouncer": name.(string), + "bouncer": bouncer.Name, "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 6afd00513..6f14dd920 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -1,30 +1,50 @@ package v1 import ( - "fmt" + "errors" "net/http" + jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -const bouncerContextKey = "bouncer_info" - func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(bouncerContextKey) + bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) if !exist { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) if !ok { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } return bouncerInfo, nil } +func getMachineIDFromContext(ctx *gin.Context) (string, error) { + claims := jwt.ExtractClaims(ctx) + if claims == nil { + return "", errors.New("failed to extract claims") + } + + rawID, ok := claims[middlewares.MachineIDKey] + if !ok { + return "", errors.New("MachineID not found in claims") + } + + id, ok := rawID.(string) + if !ok { + // should never happen + return "", errors.New("failed to cast machineID to string") + } + + return id, nil +} + func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { return func(gctx *gin.Context) { incomingIP := gctx.ClientIP() diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 41ee15b44..4e273371b 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -18,9 +18,9 @@ import ( const ( APIKeyHeader = "X-Api-Key" - bouncerContextKey = "bouncer_info" - // max allowed by bcrypt 72 = 54 bytes in base64 + BouncerContextKey = "bouncer_info" dummyAPIKeySize = 54 + // max allowed by bcrypt 72 = 54 bytes in base64 ) type APIKey struct { @@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { "name": bouncer.Name, }) - // maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - // in StreamDecision - c.Set("BOUNCER_NAME", bouncer.Name) - c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) - if bouncer.IPAddress == "" { if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) @@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - c.Set(bouncerContextKey, bouncer) + c.Set(BouncerContextKey, bouncer) c.Next() } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index ed4ad107b..6fe053713 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var identityKey = "id" +const MachineIDKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware @@ -33,7 +33,7 @@ type JWT struct { func PayloadFunc(data interface{}) jwt.MapClaims { if value, ok := data.(*models.WatcherAuthRequest); ok { return jwt.MapClaims{ - identityKey: &value.MachineID, + MachineIDKey: &value.MachineID, } } @@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims { func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineID := claims[identityKey].(string) + machineID := claims[MachineIDKey].(string) return &models.WatcherAuthRequest{ MachineID: &machineID, @@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { Key: secret, Timeout: time.Hour, MaxRefresh: time.Hour, - IdentityKey: identityKey, + IdentityKey: MachineIDKey, PayloadFunc: PayloadFunc, IdentityHandler: IdentityHandler, Authenticator: jwtMiddleware.Authenticator,