apiserver_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. package apiserver
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "net/http"
  6. "net/http/httptest"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/crowdsecurity/go-cs-lib/pkg/version"
  13. middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
  14. "github.com/crowdsecurity/crowdsec/pkg/models"
  15. "github.com/crowdsecurity/crowdsec/pkg/types"
  16. "github.com/go-openapi/strfmt"
  17. "github.com/pkg/errors"
  18. "github.com/crowdsecurity/crowdsec/pkg/csconfig"
  19. "github.com/crowdsecurity/crowdsec/pkg/database"
  20. "github.com/gin-gonic/gin"
  21. log "github.com/sirupsen/logrus"
  22. "github.com/stretchr/testify/assert"
  23. )
  24. var testMachineID = "test"
  25. var testPassword = strfmt.Password("test")
  26. var MachineTest = models.WatcherAuthRequest{
  27. MachineID: &testMachineID,
  28. Password: &testPassword,
  29. }
  30. var UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version)
  31. var emptyBody = strings.NewReader("")
  // LoadTestConfig returns a minimal LAPI configuration backed by an SQLite
// database in a fresh temporary directory, which is removed when the test ends.
// Profiles are loaded from ./tests/profiles.yaml. Any setup failure aborts the
// calling test via t.Fatalf.
func LoadTestConfig(t *testing.T) csconfig.Config {
	t.Helper()

	config := csconfig.Config{}
	maxAge := "1h"
	flushConfig := csconfig.FlushDBCfg{
		MaxAge: &maxAge,
	}

	tempDir, err := os.MkdirTemp("", "crowdsec_tests")
	if err != nil {
		t.Fatalf("failed to create temp dir: %s", err)
	}
	// Best-effort cleanup (error ignored): the database file may still be open
	// when the test ends, since the DB clients created by these helpers are never closed.
	t.Cleanup(func() { os.RemoveAll(tempDir) })

	dbconfig := csconfig.DatabaseCfg{
		Type:   "sqlite",
		DbPath: filepath.Join(tempDir, "ent"),
		Flush:  &flushConfig,
	}
	apiServerConfig := csconfig.LocalApiServerCfg{
		ListenURI:    "http://127.0.0.1:8080",
		DbConfig:     &dbconfig,
		ProfilesPath: "./tests/profiles.yaml",
		ConsoleConfig: &csconfig.ConsoleConfig{
			ShareManualDecisions:  new(bool),
			ShareTaintedScenarios: new(bool),
			ShareCustomScenarios:  new(bool),
		},
	}
	apiConfig := csconfig.APICfg{
		Server: &apiServerConfig,
	}
	config.API = &apiConfig

	// t.Fatalf (not log.Fatalf) so only this test fails and cleanups still run.
	if err := config.API.Server.LoadProfiles(); err != nil {
		t.Fatalf("failed to load profiles: %s", err)
	}
	return config
}
  // LoadTestConfigForwardedFor is LoadTestConfig with UseForwardedForHeaders
// enabled and every source (0.0.0.0/0) listed as a trusted proxy. The SQLite
// database lives in a fresh temporary directory removed when the test ends.
// Any setup failure aborts the calling test via t.Fatalf.
func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
	t.Helper()

	config := csconfig.Config{}
	maxAge := "1h"
	flushConfig := csconfig.FlushDBCfg{
		MaxAge: &maxAge,
	}

	tempDir, err := os.MkdirTemp("", "crowdsec_tests")
	if err != nil {
		t.Fatalf("failed to create temp dir: %s", err)
	}
	// Best-effort cleanup (error ignored): the database file may still be open
	// when the test ends.
	t.Cleanup(func() { os.RemoveAll(tempDir) })

	dbconfig := csconfig.DatabaseCfg{
		Type:   "sqlite",
		DbPath: filepath.Join(tempDir, "ent"),
		Flush:  &flushConfig,
	}
	apiServerConfig := csconfig.LocalApiServerCfg{
		ListenURI:              "http://127.0.0.1:8080",
		DbConfig:               &dbconfig,
		ProfilesPath:           "./tests/profiles.yaml",
		UseForwardedForHeaders: true,
		TrustedProxies:         &[]string{"0.0.0.0/0"},
		ConsoleConfig: &csconfig.ConsoleConfig{
			ShareManualDecisions:  new(bool),
			ShareTaintedScenarios: new(bool),
			ShareCustomScenarios:  new(bool),
		},
	}
	apiConfig := csconfig.APICfg{
		Server: &apiServerConfig,
	}
	config.API = &apiConfig

	// t.Fatalf (not log.Fatalf) so only this test fails and cleanups still run.
	if err := config.API.Server.LoadProfiles(); err != nil {
		t.Fatalf("failed to load profiles: %s", err)
	}
	return config
}
  // NewAPIServer builds an APIServer from LoadTestConfig and puts gin in test mode.
// It returns the server together with the config it was built from.
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
	config := LoadTestConfig(t)
	// Best-effort removal of a "./ent" file in the working directory. The test
	// database itself lives in LoadTestConfig's temp dir; this presumably clears
	// a leftover from older versions of these tests.
	os.Remove("./ent")
	apiServer, err := NewServer(config.API.Server)
	if err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	log.Printf("Creating new API server")
	gin.SetMode(gin.TestMode)
	return apiServer, config, nil
}
  // NewAPITest creates a test APIServer (see NewAPIServer), initialises its
// controllers and returns its gin router together with the config used.
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) {
	apiServer, config, err := NewAPIServer(t)
	if err != nil {
		// NewAPIServer already prefixes its errors with "unable to run local API";
		// wrapping again would duplicate that prefix.
		return nil, config, err
	}
	if err := apiServer.InitController(); err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	router, err := apiServer.Router()
	if err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	return router, config, nil
}
  // NewAPITestForwardedFor is like NewAPITest but builds the server from
// LoadTestConfigForwardedFor (UseForwardedForHeaders enabled, 0.0.0.0/0 trusted).
// It returns the gin router together with the config used.
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) {
	config := LoadTestConfigForwardedFor(t)
	// Best-effort removal of a "./ent" file in the working directory; the test
	// database lives in the config's temp dir.
	os.Remove("./ent")
	apiServer, err := NewServer(config.API.Server)
	if err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	if err := apiServer.InitController(); err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	log.Printf("Creating new API server")
	// Test mode is set before building the router, as in the original ordering.
	gin.SetMode(gin.TestMode)
	router, err := apiServer.Router()
	if err != nil {
		return nil, config, fmt.Errorf("unable to run local API: %w", err)
	}
	return router, config, nil
}
  // ValidateMachine opens a database client on config and marks machineID as
// validated through dbClient.ValidateMachine. Errors are wrapped with %w so
// callers can inspect the underlying cause.
func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
	dbClient, err := database.NewClient(config)
	if err != nil {
		return fmt.Errorf("unable to create new database client: %w", err)
	}
	if err := dbClient.ValidateMachine(machineID); err != nil {
		return fmt.Errorf("unable to validate machine: %w", err)
	}
	return nil
}
  // GetMachineIP returns the IP address stored for machineID in the database
// described by config. It returns ("", nil) when no machine has that ID.
func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) {
	dbClient, err := database.NewClient(config)
	if err != nil {
		return "", fmt.Errorf("unable to create new database client: %w", err)
	}
	machines, err := dbClient.ListMachines()
	if err != nil {
		// Go error strings start lowercase so they read well when wrapped.
		return "", fmt.Errorf("unable to list machines: %w", err)
	}
	for _, machine := range machines {
		if machine.MachineId == machineID {
			return machine.IpAddress, nil
		}
	}
	return "", nil
}
  // GetAlertReaderFromFile loads a JSON array of alerts from path, sets every
// alert's StartAt and StopAt to the current UTC time (RFC 3339), and returns the
// re-encoded JSON as a *strings.Reader. Any error is fatal (log.Fatalf) because
// the helper has no *testing.T to report to.
func GetAlertReaderFromFile(path string) *strings.Reader {
	alertContentBytes, err := os.ReadFile(path)
	if err != nil {
		log.Fatalf("unable to read alerts from %s: %s", path, err)
	}
	alerts := make([]*models.Alert, 0)
	if err := json.Unmarshal(alertContentBytes, &alerts); err != nil {
		log.Fatalf("unable to decode alerts from %s: %s", path, err)
	}
	// A single timestamp for the whole batch keeps StartAt == StopAt.
	now := time.Now().UTC().Format(time.RFC3339)
	for _, alert := range alerts {
		if alert == nil {
			// A JSON null element; it is re-encoded as null below.
			continue
		}
		// Fixtures may omit start_at/stop_at, which leaves nil pointers; assign
		// fresh values instead of writing through a possibly-nil pointer.
		startAt, stopAt := now, now
		alert.StartAt = &startAt
		alert.StopAt = &stopAt
	}
	alertContent, err := json.Marshal(alerts)
	if err != nil {
		log.Fatalf("unable to encode alerts from %s: %s", path, err)
	}
	return strings.NewReader(string(alertContent))
}
  // readDecisionsGetResp decodes a recorded response body as a list of decisions.
// It returns the list, the recorded HTTP status code and any JSON error. A nil
// recorder yields status 0 and an error; on decode failure the list is nil.
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
	if resp == nil {
		return nil, 0, errors.New("response is nil")
	}
	var decisions []*models.Decision
	if decodeErr := json.Unmarshal(resp.Body.Bytes(), &decisions); decodeErr != nil {
		return nil, resp.Code, decodeErr
	}
	return decisions, resp.Code, nil
}
  // readDecisionsErrorResp decodes a recorded error response body as a flat
// string-to-string JSON object. It returns the object, the recorded HTTP status
// code and any JSON error. A nil recorder yields status 0 and an error; on
// decode failure the map is nil.
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
	if resp == nil {
		return nil, 0, errors.New("response is nil")
	}
	var payload map[string]string
	if decodeErr := json.Unmarshal(resp.Body.Bytes(), &payload); decodeErr != nil {
		return nil, resp.Code, decodeErr
	}
	return payload, resp.Code, nil
}
  210. func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
  211. var response models.DeleteDecisionResponse
  212. if resp == nil {
  213. return nil, 0, errors.New("response is nil")
  214. }
  215. err := json.Unmarshal(resp.Body.Bytes(), &response)
  216. if err != nil {
  217. return nil, resp.Code, err
  218. }
  219. return &response, resp.Code, nil
  220. }
  // readDecisionsStreamResp decodes a recorded stream response body into a map
// from key to list of decisions. It returns the map, the recorded HTTP status
// code and any JSON error. A nil recorder yields status 0 and an error; on
// decode failure the map is nil.
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
	if resp == nil {
		return nil, 0, errors.New("response is nil")
	}
	stream := make(map[string][]*models.Decision)
	if decodeErr := json.Unmarshal(resp.Body.Bytes(), &stream); decodeErr != nil {
		return nil, resp.Code, decodeErr
	}
	return stream, resp.Code, nil
}
  // CreateTestMachine registers MachineTest by POSTing it to /v1/watchers on
// router and returns the JSON body that was sent. The response status is not
// checked; callers that need to know whether registration succeeded must
// verify it separately.
func CreateTestMachine(router *gin.Engine) (string, error) {
	b, err := json.Marshal(MachineTest)
	if err != nil {
		return "", fmt.Errorf("unable to marshal MachineTest: %w", err)
	}
	body := string(b)
	w := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("unable to build registration request: %w", err)
	}
	req.Header.Set("User-Agent", UserAgent)
	router.ServeHTTP(w, req)
	return body, nil
}
  // CreateTestBouncer inserts a bouncer named "test" (IP 127.0.0.1, API-key auth)
// directly into the database described by config and returns its clear-text
// API key. The key is stored hashed with middlewares.HashSHA512.
func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
	dbClient, err := database.NewClient(config)
	if err != nil {
		// Return the error instead of log.Fatalf: the signature already reports
		// failures, and exiting would skip the caller's cleanups.
		return "", fmt.Errorf("unable to create new database client: %w", err)
	}
	apiKey, err := middlewares.GenerateAPIKey(keyLength)
	if err != nil {
		return "", fmt.Errorf("unable to generate api key: %w", err)
	}
	_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
	if err != nil {
		return "", fmt.Errorf("unable to create blocker: %w", err)
	}
	return apiKey, nil
}
  // TestWithWrongDBConfig checks that NewServer rejects an unknown database type
// and returns an empty (non-nil) APIServer with a descriptive error.
func TestWithWrongDBConfig(t *testing.T) {
	config := LoadTestConfig(t)
	config.API.Server.DbConfig.Type = "test"
	apiServer, err := NewServer(config.API.Server)
	// testify's convention is (expected, actual).
	assert.Equal(t, &APIServer{}, apiServer)
	// EqualError reports a nil error as a failure instead of panicking on err.Error().
	assert.EqualError(t, err, "unable to init database client: unknown database type 'test'")
}
  // TestWithWrongFlushConfig checks that NewServer rejects a negative flush
// max_items and returns an empty (non-nil) APIServer with a descriptive error.
func TestWithWrongFlushConfig(t *testing.T) {
	config := LoadTestConfig(t)
	maxItems := -1
	config.API.Server.DbConfig.Flush.MaxItems = &maxItems
	apiServer, err := NewServer(config.API.Server)
	// testify's convention is (expected, actual).
	assert.Equal(t, &APIServer{}, apiServer)
	// EqualError reports a nil error as a failure instead of panicking on err.Error().
	assert.EqualError(t, err, "max_items can't be zero or negative number")
}
  // TestUnknownPath checks that a GET on an unregistered route returns 404.
func TestUnknownPath(t *testing.T) {
	router, _, err := NewAPITest(t)
	if err != nil {
		// t.Fatalf fails only this test; log.Fatalf would exit the test binary.
		t.Fatalf("unable to run local API: %s", err)
	}
	w := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		t.Fatalf("unable to build request: %s", err)
	}
	req.Header.Set("User-Agent", UserAgent)
	router.ServeHTTP(w, req)
	assert.Equal(t, http.StatusNotFound, w.Code)
}
  /*
Reference copy of selected csconfig.LocalApiServerCfg fields, kept here for the
logging tests below (the authoritative definition lives in package csconfig):

ListenURI string `yaml:"listen_uri,omitempty"` //127.0.0.1:8080
TLS *TLSCfg `yaml:"tls"`
DbConfig *DatabaseCfg `yaml:"-"`
LogDir string `yaml:"-"`
LogMedia string `yaml:"-"`
OnlineClient *OnlineApiClientCfg `yaml:"online_client"`
ProfilesPath string `yaml:"profiles_path,omitempty"`
Profiles []*ProfileCfg `yaml:"-"`
LogLevel *log.Level `yaml:"log_level"`
UseForwardedForHeaders bool `yaml:"use_forwarded_for_headers,omitempty"`
*/
  // TestLoggingDebugToFileConfig configures file logging at debug level and checks
// that a request to an unknown route ("/test42") appears in
// <LogDir>/crowdsec_api.log.
func TestLoggingDebugToFileConfig(t *testing.T) {
	/*declare settings*/
	maxAge := "1h"
	flushConfig := csconfig.FlushDBCfg{
		MaxAge: &maxAge,
	}
	tempDir, err := os.MkdirTemp("", "crowdsec_tests")
	if err != nil {
		t.Fatalf("failed to create temp dir: %s", err)
	}
	t.Cleanup(func() { os.RemoveAll(tempDir) })
	dbconfig := csconfig.DatabaseCfg{
		Type:   "sqlite",
		DbPath: filepath.Join(tempDir, "ent"),
		Flush:  &flushConfig,
	}
	cfg := csconfig.LocalApiServerCfg{
		ListenURI: "127.0.0.1:8080",
		LogMedia:  "file",
		LogDir:    tempDir,
		DbConfig:  &dbconfig,
	}
	lvl := log.DebugLevel
	expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
	expectedLines := []string{"/test42"}
	cfg.LogLevel = &lvl

	// Configure logging
	if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
		t.Fatal(err)
	}
	api, err := NewServer(&cfg)
	if err != nil {
		t.Fatalf("failed to create api : %s", err)
	}
	if api == nil {
		t.Fatalf("failed to create api #2 is nil")
	}

	w := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, "/test42", nil)
	if err != nil {
		t.Fatalf("failed to build request : %s", err)
	}
	req.Header.Set("User-Agent", UserAgent)
	api.router.ServeHTTP(w, req)
	assert.Equal(t, http.StatusNotFound, w.Code)

	// wait for the request to happen
	time.Sleep(500 * time.Millisecond)

	// check file content
	data, err := os.ReadFile(expectedFile)
	if err != nil {
		t.Fatalf("failed to read file : %s", err)
	}
	for _, expectedStr := range expectedLines {
		if !strings.Contains(string(data), expectedStr) {
			t.Fatalf("expected %s in %s", expectedStr, string(data))
		}
	}
}
  // TestLoggingErrorToFileConfig configures file logging at error level and checks
// that a 404 request writes nothing to <LogDir>/crowdsec_api.log. A missing log
// file is also accepted.
func TestLoggingErrorToFileConfig(t *testing.T) {
	/*declare settings*/
	maxAge := "1h"
	flushConfig := csconfig.FlushDBCfg{
		MaxAge: &maxAge,
	}
	tempDir, err := os.MkdirTemp("", "crowdsec_tests")
	if err != nil {
		t.Fatalf("failed to create temp dir: %s", err)
	}
	t.Cleanup(func() { os.RemoveAll(tempDir) })
	dbconfig := csconfig.DatabaseCfg{
		Type:   "sqlite",
		DbPath: filepath.Join(tempDir, "ent"),
		Flush:  &flushConfig,
	}
	cfg := csconfig.LocalApiServerCfg{
		ListenURI: "127.0.0.1:8080",
		LogMedia:  "file",
		LogDir:    tempDir,
		DbConfig:  &dbconfig,
	}
	lvl := log.ErrorLevel
	expectedFile := filepath.Join(tempDir, "crowdsec_api.log")
	cfg.LogLevel = &lvl

	// Configure logging
	if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
		t.Fatal(err)
	}
	api, err := NewServer(&cfg)
	if err != nil {
		t.Fatalf("failed to create api : %s", err)
	}
	if api == nil {
		t.Fatalf("failed to create api #2 is nil")
	}

	w := httptest.NewRecorder()
	req, err := http.NewRequest(http.MethodGet, "/test42", nil)
	if err != nil {
		t.Fatalf("failed to build request : %s", err)
	}
	req.Header.Set("User-Agent", UserAgent)
	api.router.ServeHTTP(w, req)
	assert.Equal(t, http.StatusNotFound, w.Code)

	// wait for the request to happen
	time.Sleep(500 * time.Millisecond)

	// check file content: at error level the 404 must not be logged
	x, err := os.ReadFile(expectedFile)
	if err == nil && len(x) > 0 {
		t.Fatalf("file should be empty, got '%s'", x)
	}
	// Best-effort removal of stray log files (errors ignored).
	os.Remove("./crowdsec.log")
	os.Remove(expectedFile)
}