123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444 |
- package apiserver
- import (
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "time"
- middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
- "github.com/crowdsecurity/crowdsec/pkg/cwversion"
- "github.com/crowdsecurity/crowdsec/pkg/models"
- "github.com/crowdsecurity/crowdsec/pkg/types"
- "github.com/go-openapi/strfmt"
- "github.com/pkg/errors"
- "github.com/crowdsecurity/crowdsec/pkg/csconfig"
- "github.com/crowdsecurity/crowdsec/pkg/database"
- "github.com/gin-gonic/gin"
- log "github.com/sirupsen/logrus"
- "github.com/stretchr/testify/assert"
- )
- var testMachineID = "test"
- var testPassword = strfmt.Password("test")
- var MachineTest = models.WatcherAuthRequest{
- MachineID: &testMachineID,
- Password: &testPassword,
- }
- var UserAgent = fmt.Sprintf("crowdsec-test/%s", cwversion.Version)
- var emptyBody = strings.NewReader("")
- func LoadTestConfig() csconfig.Config {
- config := csconfig.Config{}
- maxAge := "1h"
- flushConfig := csconfig.FlushDBCfg{
- MaxAge: &maxAge,
- }
- tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
- dbconfig := csconfig.DatabaseCfg{
- Type: "sqlite",
- DbPath: filepath.Join(tempDir, "ent"),
- Flush: &flushConfig,
- }
- apiServerConfig := csconfig.LocalApiServerCfg{
- ListenURI: "http://127.0.0.1:8080",
- DbConfig: &dbconfig,
- ProfilesPath: "./tests/profiles.yaml",
- ConsoleConfig: &csconfig.ConsoleConfig{
- ShareManualDecisions: new(bool),
- ShareTaintedScenarios: new(bool),
- ShareCustomScenarios: new(bool),
- },
- }
- apiConfig := csconfig.APICfg{
- Server: &apiServerConfig,
- }
- config.API = &apiConfig
- if err := config.API.Server.LoadProfiles(); err != nil {
- log.Fatalf("failed to load profiles: %s", err)
- }
- return config
- }
- func LoadTestConfigForwardedFor() csconfig.Config {
- config := csconfig.Config{}
- maxAge := "1h"
- flushConfig := csconfig.FlushDBCfg{
- MaxAge: &maxAge,
- }
- tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
- dbconfig := csconfig.DatabaseCfg{
- Type: "sqlite",
- DbPath: filepath.Join(tempDir, "ent"),
- Flush: &flushConfig,
- }
- apiServerConfig := csconfig.LocalApiServerCfg{
- ListenURI: "http://127.0.0.1:8080",
- DbConfig: &dbconfig,
- ProfilesPath: "./tests/profiles.yaml",
- UseForwardedForHeaders: true,
- TrustedProxies: &[]string{"0.0.0.0/0"},
- ConsoleConfig: &csconfig.ConsoleConfig{
- ShareManualDecisions: new(bool),
- ShareTaintedScenarios: new(bool),
- ShareCustomScenarios: new(bool),
- },
- }
- apiConfig := csconfig.APICfg{
- Server: &apiServerConfig,
- }
- config.API = &apiConfig
- if err := config.API.Server.LoadProfiles(); err != nil {
- log.Fatalf("failed to load profiles: %s", err)
- }
- return config
- }
- func NewAPIServer() (*APIServer, csconfig.Config, error) {
- config := LoadTestConfig()
- os.Remove("./ent")
- apiServer, err := NewServer(config.API.Server)
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- log.Printf("Creating new API server")
- gin.SetMode(gin.TestMode)
- return apiServer, config, nil
- }
- func NewAPITest() (*gin.Engine, csconfig.Config, error) {
- apiServer, config, err := NewAPIServer()
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- err = apiServer.InitController()
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- router, err := apiServer.Router()
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- return router, config, nil
- }
- func NewAPITestForwardedFor() (*gin.Engine, csconfig.Config, error) {
- config := LoadTestConfigForwardedFor()
- os.Remove("./ent")
- apiServer, err := NewServer(config.API.Server)
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- err = apiServer.InitController()
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- log.Printf("Creating new API server")
- gin.SetMode(gin.TestMode)
- router, err := apiServer.Router()
- if err != nil {
- return nil, config, fmt.Errorf("unable to run local API: %s", err)
- }
- return router, config, nil
- }
- func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
- dbClient, err := database.NewClient(config)
- if err != nil {
- return fmt.Errorf("unable to create new database client: %s", err)
- }
- if err := dbClient.ValidateMachine(machineID); err != nil {
- return fmt.Errorf("unable to validate machine: %s", err)
- }
- return nil
- }
- func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) {
- dbClient, err := database.NewClient(config)
- if err != nil {
- return "", fmt.Errorf("unable to create new database client: %s", err)
- }
- machines, err := dbClient.ListMachines()
- if err != nil {
- return "", fmt.Errorf("Unable to list machines: %s", err)
- }
- for _, machine := range machines {
- if machine.MachineId == machineID {
- return machine.IpAddress, nil
- }
- }
- return "", nil
- }
- func GetAlertReaderFromFile(path string) *strings.Reader {
- alertContentBytes, err := os.ReadFile(path)
- if err != nil {
- log.Fatal(err)
- }
- alerts := make([]*models.Alert, 0)
- if err := json.Unmarshal(alertContentBytes, &alerts); err != nil {
- log.Fatal(err)
- }
- for _, alert := range alerts {
- *alert.StartAt = time.Now().UTC().Format(time.RFC3339)
- *alert.StopAt = time.Now().UTC().Format(time.RFC3339)
- }
- alertContent, err := json.Marshal(alerts)
- if err != nil {
- log.Fatal(err)
- }
- return strings.NewReader(string(alertContent))
- }
- func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
- var response []*models.Decision
- if resp == nil {
- return nil, 0, errors.New("response is nil")
- }
- err := json.Unmarshal(resp.Body.Bytes(), &response)
- if err != nil {
- return nil, resp.Code, err
- }
- return response, resp.Code, nil
- }
// readDecisionsErrorResp decodes the body of a recorded response as a JSON
// object with string values (the shape of an error payload) and returns it
// together with the HTTP status code.
// A nil recorder yields an error and status 0. A decode failure yields a nil
// map, the recorded status, and the json error.
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
	if resp == nil {
		return nil, 0, errors.New("response is nil")
	}

	var payload map[string]string
	if err := json.Unmarshal(resp.Body.Bytes(), &payload); err != nil {
		return nil, resp.Code, err
	}

	return payload, resp.Code, nil
}
- func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
- var response models.DeleteDecisionResponse
- if resp == nil {
- return nil, 0, errors.New("response is nil")
- }
- err := json.Unmarshal(resp.Body.Bytes(), &response)
- if err != nil {
- return nil, resp.Code, err
- }
- return &response, resp.Code, nil
- }
- func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
- response := make(map[string][]*models.Decision)
- if resp == nil {
- return nil, 0, errors.New("response is nil")
- }
- err := json.Unmarshal(resp.Body.Bytes(), &response)
- if err != nil {
- return nil, resp.Code, err
- }
- return response, resp.Code, nil
- }
- func CreateTestMachine(router *gin.Engine) (string, error) {
- b, err := json.Marshal(MachineTest)
- if err != nil {
- return "", fmt.Errorf("unable to marshal MachineTest")
- }
- body := string(b)
- w := httptest.NewRecorder()
- req, _ := http.NewRequest("POST", "/v1/watchers", strings.NewReader(body))
- req.Header.Set("User-Agent", UserAgent)
- router.ServeHTTP(w, req)
- return body, nil
- }
- func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
- dbClient, err := database.NewClient(config)
- if err != nil {
- log.Fatalf("unable to create new database client: %s", err)
- }
- apiKey, err := middlewares.GenerateAPIKey(keyLength)
- if err != nil {
- return "", fmt.Errorf("unable to generate api key: %s", err)
- }
- err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey))
- if err != nil {
- return "", fmt.Errorf("unable to create blocker: %s", err)
- }
- return apiKey, nil
- }
- func TestWithWrongDBConfig(t *testing.T) {
- config := LoadTestConfig()
- config.API.Server.DbConfig.Type = "test"
- apiServer, err := NewServer(config.API.Server)
- assert.Equal(t, apiServer, &APIServer{})
- assert.Equal(t, "unable to init database client: unknown database type", err.Error())
- }
- func TestWithWrongFlushConfig(t *testing.T) {
- config := LoadTestConfig()
- maxItems := -1
- config.API.Server.DbConfig.Flush.MaxItems = &maxItems
- apiServer, err := NewServer(config.API.Server)
- assert.Equal(t, apiServer, &APIServer{})
- assert.Equal(t, "max_items can't be zero or negative number", err.Error())
- }
- func TestUnknownPath(t *testing.T) {
- router, _, err := NewAPITest()
- if err != nil {
- log.Fatalf("unable to run local API: %s", err)
- }
- w := httptest.NewRecorder()
- req, _ := http.NewRequest("GET", "/test", nil)
- req.Header.Set("User-Agent", UserAgent)
- router.ServeHTTP(w, req)
- assert.Equal(t, 404, w.Code)
- }
/*
Reference: fields of csconfig.LocalApiServerCfg that the logging tests below
set or read. This appears to be copied from the struct definition and may be
out of date; check pkg/csconfig for the authoritative version.

ListenURI string `yaml:"listen_uri,omitempty"` //127.0.0.1:8080
TLS *TLSCfg `yaml:"tls"`
DbConfig *DatabaseCfg `yaml:"-"`
LogDir string `yaml:"-"`
LogMedia string `yaml:"-"`
OnlineClient *OnlineApiClientCfg `yaml:"online_client"`
ProfilesPath string `yaml:"profiles_path,omitempty"`
Profiles []*ProfileCfg `yaml:"-"`
LogLevel *log.Level `yaml:"log_level"`
UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"`
*/
- func TestLoggingDebugToFileConfig(t *testing.T) {
- /*declare settings*/
- maxAge := "1h"
- flushConfig := csconfig.FlushDBCfg{
- MaxAge: &maxAge,
- }
- tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
- dbconfig := csconfig.DatabaseCfg{
- Type: "sqlite",
- DbPath: filepath.Join(tempDir, "ent"),
- Flush: &flushConfig,
- }
- cfg := csconfig.LocalApiServerCfg{
- ListenURI: "127.0.0.1:8080",
- LogMedia: "file",
- LogDir: tempDir,
- DbConfig: &dbconfig,
- }
- lvl := log.DebugLevel
- expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
- expectedLines := []string{"/test42"}
- cfg.LogLevel = &lvl
- // Configure logging
- if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs); err != nil {
- t.Fatal(err.Error())
- }
- api, err := NewServer(&cfg)
- if err != nil {
- t.Fatalf("failed to create api : %s", err)
- }
- if api == nil {
- t.Fatalf("failed to create api #2 is nbill")
- }
- w := httptest.NewRecorder()
- req, _ := http.NewRequest("GET", "/test42", nil)
- req.Header.Set("User-Agent", UserAgent)
- api.router.ServeHTTP(w, req)
- assert.Equal(t, 404, w.Code)
- //wait for the request to happen
- time.Sleep(500 * time.Millisecond)
- //check file content
- data, err := os.ReadFile(expectedFile)
- if err != nil {
- t.Fatalf("failed to read file : %s", err)
- }
- for _, expectedStr := range expectedLines {
- if !strings.Contains(string(data), expectedStr) {
- t.Fatalf("expected %s in %s", expectedStr, string(data))
- }
- }
- }
- func TestLoggingErrorToFileConfig(t *testing.T) {
- /*declare settings*/
- maxAge := "1h"
- flushConfig := csconfig.FlushDBCfg{
- MaxAge: &maxAge,
- }
- tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
- dbconfig := csconfig.DatabaseCfg{
- Type: "sqlite",
- DbPath: filepath.Join(tempDir, "ent"),
- Flush: &flushConfig,
- }
- cfg := csconfig.LocalApiServerCfg{
- ListenURI: "127.0.0.1:8080",
- LogMedia: "file",
- LogDir: tempDir,
- DbConfig: &dbconfig,
- }
- lvl := log.ErrorLevel
- expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
- cfg.LogLevel = &lvl
- // Configure logging
- if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs); err != nil {
- t.Fatal(err.Error())
- }
- api, err := NewServer(&cfg)
- if err != nil {
- t.Fatalf("failed to create api : %s", err)
- }
- if api == nil {
- t.Fatalf("failed to create api #2 is nbill")
- }
- w := httptest.NewRecorder()
- req, _ := http.NewRequest("GET", "/test42", nil)
- req.Header.Set("User-Agent", UserAgent)
- api.router.ServeHTTP(w, req)
- assert.Equal(t, 404, w.Code)
- //wait for the request to happen
- time.Sleep(500 * time.Millisecond)
- //check file content
- x, err := os.ReadFile(expectedFile)
- if err == nil && len(x) > 0 {
- t.Fatalf("file should be empty, got '%s'", x)
- }
- os.Remove("./crowdsec.log")
- os.Remove(expectedFile)
- }
|