apiserver_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. package apiserver
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "net/http/httptest"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "testing"
  11. "time"
  12. middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
  13. "github.com/crowdsecurity/crowdsec/pkg/cwversion"
  14. "github.com/crowdsecurity/crowdsec/pkg/models"
  15. "github.com/crowdsecurity/crowdsec/pkg/types"
  16. "github.com/go-openapi/strfmt"
  17. "github.com/pkg/errors"
  18. "github.com/crowdsecurity/crowdsec/pkg/csconfig"
  19. "github.com/crowdsecurity/crowdsec/pkg/database"
  20. "github.com/gin-gonic/gin"
  21. log "github.com/sirupsen/logrus"
  22. "github.com/stretchr/testify/assert"
  23. )
  24. var testMachineID = "test"
  25. var testPassword = strfmt.Password("test")
  26. var MachineTest = models.WatcherAuthRequest{
  27. MachineID: &testMachineID,
  28. Password: &testPassword,
  29. }
  30. var UserAgent = fmt.Sprintf("crowdsec-test/%s", cwversion.Version)
  31. var emptyBody = strings.NewReader("")
  32. func LoadTestConfig() csconfig.Config {
  33. config := csconfig.Config{}
  34. maxAge := "1h"
  35. flushConfig := csconfig.FlushDBCfg{
  36. MaxAge: &maxAge,
  37. }
  38. tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
  39. dbconfig := csconfig.DatabaseCfg{
  40. Type: "sqlite",
  41. DbPath: filepath.Join(tempDir, "ent"),
  42. Flush: &flushConfig,
  43. }
  44. apiServerConfig := csconfig.LocalApiServerCfg{
  45. ListenURI: "http://127.0.0.1:8080",
  46. DbConfig: &dbconfig,
  47. ProfilesPath: "./tests/profiles.yaml",
  48. ConsoleConfig: &csconfig.ConsoleConfig{
  49. ShareManualDecisions: new(bool),
  50. ShareTaintedScenarios: new(bool),
  51. ShareCustomScenarios: new(bool),
  52. },
  53. }
  54. apiConfig := csconfig.APICfg{
  55. Server: &apiServerConfig,
  56. }
  57. config.API = &apiConfig
  58. if err := config.API.Server.LoadProfiles(); err != nil {
  59. log.Fatalf("failed to load profiles: %s", err)
  60. }
  61. return config
  62. }
  63. func LoadTestConfigForwardedFor() csconfig.Config {
  64. config := csconfig.Config{}
  65. maxAge := "1h"
  66. flushConfig := csconfig.FlushDBCfg{
  67. MaxAge: &maxAge,
  68. }
  69. tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
  70. dbconfig := csconfig.DatabaseCfg{
  71. Type: "sqlite",
  72. DbPath: filepath.Join(tempDir, "ent"),
  73. Flush: &flushConfig,
  74. }
  75. apiServerConfig := csconfig.LocalApiServerCfg{
  76. ListenURI: "http://127.0.0.1:8080",
  77. DbConfig: &dbconfig,
  78. ProfilesPath: "./tests/profiles.yaml",
  79. UseForwardedForHeaders: true,
  80. TrustedProxies: &[]string{"0.0.0.0/0"},
  81. ConsoleConfig: &csconfig.ConsoleConfig{
  82. ShareManualDecisions: new(bool),
  83. ShareTaintedScenarios: new(bool),
  84. ShareCustomScenarios: new(bool),
  85. },
  86. }
  87. apiConfig := csconfig.APICfg{
  88. Server: &apiServerConfig,
  89. }
  90. config.API = &apiConfig
  91. if err := config.API.Server.LoadProfiles(); err != nil {
  92. log.Fatalf("failed to load profiles: %s", err)
  93. }
  94. return config
  95. }
  96. func NewAPIServer() (*APIServer, csconfig.Config, error) {
  97. config := LoadTestConfig()
  98. os.Remove("./ent")
  99. apiServer, err := NewServer(config.API.Server)
  100. if err != nil {
  101. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  102. }
  103. log.Printf("Creating new API server")
  104. gin.SetMode(gin.TestMode)
  105. return apiServer, config, nil
  106. }
  107. func NewAPITest() (*gin.Engine, csconfig.Config, error) {
  108. apiServer, config, err := NewAPIServer()
  109. if err != nil {
  110. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  111. }
  112. err = apiServer.InitController()
  113. if err != nil {
  114. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  115. }
  116. router, err := apiServer.Router()
  117. if err != nil {
  118. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  119. }
  120. return router, config, nil
  121. }
  122. func NewAPITestForwardedFor() (*gin.Engine, csconfig.Config, error) {
  123. config := LoadTestConfigForwardedFor()
  124. os.Remove("./ent")
  125. apiServer, err := NewServer(config.API.Server)
  126. if err != nil {
  127. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  128. }
  129. err = apiServer.InitController()
  130. if err != nil {
  131. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  132. }
  133. log.Printf("Creating new API server")
  134. gin.SetMode(gin.TestMode)
  135. router, err := apiServer.Router()
  136. if err != nil {
  137. return nil, config, fmt.Errorf("unable to run local API: %s", err)
  138. }
  139. return router, config, nil
  140. }
  141. func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
  142. dbClient, err := database.NewClient(config)
  143. if err != nil {
  144. return fmt.Errorf("unable to create new database client: %s", err)
  145. }
  146. if err := dbClient.ValidateMachine(machineID); err != nil {
  147. return fmt.Errorf("unable to validate machine: %s", err)
  148. }
  149. return nil
  150. }
  151. func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) {
  152. dbClient, err := database.NewClient(config)
  153. if err != nil {
  154. return "", fmt.Errorf("unable to create new database client: %s", err)
  155. }
  156. machines, err := dbClient.ListMachines()
  157. if err != nil {
  158. return "", fmt.Errorf("Unable to list machines: %s", err)
  159. }
  160. for _, machine := range machines {
  161. if machine.MachineId == machineID {
  162. return machine.IpAddress, nil
  163. }
  164. }
  165. return "", nil
  166. }
  167. func GetAlertReaderFromFile(path string) *strings.Reader {
  168. alertContentBytes, err := os.ReadFile(path)
  169. if err != nil {
  170. log.Fatal(err)
  171. }
  172. alerts := make([]*models.Alert, 0)
  173. if err := json.Unmarshal(alertContentBytes, &alerts); err != nil {
  174. log.Fatal(err)
  175. }
  176. for _, alert := range alerts {
  177. *alert.StartAt = time.Now().UTC().Format(time.RFC3339)
  178. *alert.StopAt = time.Now().UTC().Format(time.RFC3339)
  179. }
  180. alertContent, err := json.Marshal(alerts)
  181. if err != nil {
  182. log.Fatal(err)
  183. }
  184. return strings.NewReader(string(alertContent))
  185. }
  186. func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
  187. var response []*models.Decision
  188. if resp == nil {
  189. return nil, 0, errors.New("response is nil")
  190. }
  191. err := json.Unmarshal(resp.Body.Bytes(), &response)
  192. if err != nil {
  193. return nil, resp.Code, err
  194. }
  195. return response, resp.Code, nil
  196. }
  197. func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
  198. var response map[string]string
  199. if resp == nil {
  200. return nil, 0, errors.New("response is nil")
  201. }
  202. err := json.Unmarshal(resp.Body.Bytes(), &response)
  203. if err != nil {
  204. return nil, resp.Code, err
  205. }
  206. return response, resp.Code, nil
  207. }
  208. func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
  209. var response models.DeleteDecisionResponse
  210. if resp == nil {
  211. return nil, 0, errors.New("response is nil")
  212. }
  213. err := json.Unmarshal(resp.Body.Bytes(), &response)
  214. if err != nil {
  215. return nil, resp.Code, err
  216. }
  217. return &response, resp.Code, nil
  218. }
  219. func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
  220. response := make(map[string][]*models.Decision)
  221. if resp == nil {
  222. return nil, 0, errors.New("response is nil")
  223. }
  224. err := json.Unmarshal(resp.Body.Bytes(), &response)
  225. if err != nil {
  226. return nil, resp.Code, err
  227. }
  228. return response, resp.Code, nil
  229. }
  230. func CreateTestMachine(router *gin.Engine) (string, error) {
  231. b, err := json.Marshal(MachineTest)
  232. if err != nil {
  233. return "", fmt.Errorf("unable to marshal MachineTest")
  234. }
  235. body := string(b)
  236. w := httptest.NewRecorder()
  237. req, _ := http.NewRequest("POST", "/v1/watchers", strings.NewReader(body))
  238. req.Header.Set("User-Agent", UserAgent)
  239. router.ServeHTTP(w, req)
  240. return body, nil
  241. }
  242. func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
  243. dbClient, err := database.NewClient(config)
  244. if err != nil {
  245. log.Fatalf("unable to create new database client: %s", err)
  246. }
  247. apiKey, err := middlewares.GenerateAPIKey(keyLength)
  248. if err != nil {
  249. return "", fmt.Errorf("unable to generate api key: %s", err)
  250. }
  251. err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey))
  252. if err != nil {
  253. return "", fmt.Errorf("unable to create blocker: %s", err)
  254. }
  255. return apiKey, nil
  256. }
  257. func TestWithWrongDBConfig(t *testing.T) {
  258. config := LoadTestConfig()
  259. config.API.Server.DbConfig.Type = "test"
  260. apiServer, err := NewServer(config.API.Server)
  261. assert.Equal(t, apiServer, &APIServer{})
  262. assert.Equal(t, "unable to init database client: unknown database type", err.Error())
  263. }
  264. func TestWithWrongFlushConfig(t *testing.T) {
  265. config := LoadTestConfig()
  266. maxItems := -1
  267. config.API.Server.DbConfig.Flush.MaxItems = &maxItems
  268. apiServer, err := NewServer(config.API.Server)
  269. assert.Equal(t, apiServer, &APIServer{})
  270. assert.Equal(t, "max_items can't be zero or negative number", err.Error())
  271. }
  272. func TestUnknownPath(t *testing.T) {
  273. router, _, err := NewAPITest()
  274. if err != nil {
  275. log.Fatalf("unable to run local API: %s", err)
  276. }
  277. w := httptest.NewRecorder()
  278. req, _ := http.NewRequest("GET", "/test", nil)
  279. req.Header.Set("User-Agent", UserAgent)
  280. router.ServeHTTP(w, req)
  281. assert.Equal(t, 404, w.Code)
  282. }
  283. /*
  284. ListenURI string `yaml:"listen_uri,omitempty"` //127.0.0.1:8080
  285. TLS *TLSCfg `yaml:"tls"`
  286. DbConfig *DatabaseCfg `yaml:"-"`
  287. LogDir string `yaml:"-"`
  288. LogMedia string `yaml:"-"`
  289. OnlineClient *OnlineApiClientCfg `yaml:"online_client"`
  290. ProfilesPath string `yaml:"profiles_path,omitempty"`
  291. Profiles []*ProfileCfg `yaml:"-"`
  292. LogLevel *log.Level `yaml:"log_level"`
  293. UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"`
  294. */
  295. func TestLoggingDebugToFileConfig(t *testing.T) {
  296. /*declare settings*/
  297. maxAge := "1h"
  298. flushConfig := csconfig.FlushDBCfg{
  299. MaxAge: &maxAge,
  300. }
  301. tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
  302. dbconfig := csconfig.DatabaseCfg{
  303. Type: "sqlite",
  304. DbPath: filepath.Join(tempDir, "ent"),
  305. Flush: &flushConfig,
  306. }
  307. cfg := csconfig.LocalApiServerCfg{
  308. ListenURI: "127.0.0.1:8080",
  309. LogMedia: "file",
  310. LogDir: tempDir,
  311. DbConfig: &dbconfig,
  312. }
  313. lvl := log.DebugLevel
  314. expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
  315. expectedLines := []string{"/test42"}
  316. cfg.LogLevel = &lvl
  317. // Configure logging
  318. if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs); err != nil {
  319. t.Fatal(err.Error())
  320. }
  321. api, err := NewServer(&cfg)
  322. if err != nil {
  323. t.Fatalf("failed to create api : %s", err)
  324. }
  325. if api == nil {
  326. t.Fatalf("failed to create api #2 is nbill")
  327. }
  328. w := httptest.NewRecorder()
  329. req, _ := http.NewRequest("GET", "/test42", nil)
  330. req.Header.Set("User-Agent", UserAgent)
  331. api.router.ServeHTTP(w, req)
  332. assert.Equal(t, 404, w.Code)
  333. //wait for the request to happen
  334. time.Sleep(500 * time.Millisecond)
  335. //check file content
  336. data, err := os.ReadFile(expectedFile)
  337. if err != nil {
  338. t.Fatalf("failed to read file : %s", err)
  339. }
  340. for _, expectedStr := range expectedLines {
  341. if !strings.Contains(string(data), expectedStr) {
  342. t.Fatalf("expected %s in %s", expectedStr, string(data))
  343. }
  344. }
  345. }
  346. func TestLoggingErrorToFileConfig(t *testing.T) {
  347. /*declare settings*/
  348. maxAge := "1h"
  349. flushConfig := csconfig.FlushDBCfg{
  350. MaxAge: &maxAge,
  351. }
  352. tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
  353. dbconfig := csconfig.DatabaseCfg{
  354. Type: "sqlite",
  355. DbPath: filepath.Join(tempDir, "ent"),
  356. Flush: &flushConfig,
  357. }
  358. cfg := csconfig.LocalApiServerCfg{
  359. ListenURI: "127.0.0.1:8080",
  360. LogMedia: "file",
  361. LogDir: tempDir,
  362. DbConfig: &dbconfig,
  363. }
  364. lvl := log.ErrorLevel
  365. expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
  366. cfg.LogLevel = &lvl
  367. // Configure logging
  368. if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs); err != nil {
  369. t.Fatal(err.Error())
  370. }
  371. api, err := NewServer(&cfg)
  372. if err != nil {
  373. t.Fatalf("failed to create api : %s", err)
  374. }
  375. if api == nil {
  376. t.Fatalf("failed to create api #2 is nbill")
  377. }
  378. w := httptest.NewRecorder()
  379. req, _ := http.NewRequest("GET", "/test42", nil)
  380. req.Header.Set("User-Agent", UserAgent)
  381. api.router.ServeHTTP(w, req)
  382. assert.Equal(t, 404, w.Code)
  383. //wait for the request to happen
  384. time.Sleep(500 * time.Millisecond)
  385. //check file content
  386. x, err := os.ReadFile(expectedFile)
  387. if err == nil && len(x) > 0 {
  388. t.Fatalf("file should be empty, got '%s'", x)
  389. }
  390. os.Remove("./crowdsec.log")
  391. os.Remove(expectedFile)
  392. }