260332c726
* Add use_forwarded_for_headers configuration option for LAPI * update documentation
216 lines
5.6 KiB
Go
216 lines
5.6 KiB
Go
package apiserver
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
|
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
|
"github.com/go-openapi/strfmt"
|
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
"github.com/gin-gonic/gin"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
var testMachineID = "test"
|
|
var testPassword = strfmt.Password("test")
|
|
var MachineTest = models.WatcherAuthRequest{
|
|
MachineID: &testMachineID,
|
|
Password: &testPassword,
|
|
}
|
|
|
|
var UserAgent = fmt.Sprintf("crowdsec-test/%s", cwversion.Version)
|
|
|
|
func LoadTestConfig() csconfig.GlobalConfig {
|
|
config := csconfig.GlobalConfig{}
|
|
maxAge := "1h"
|
|
flushConfig := csconfig.FlushDBCfg{
|
|
MaxAge: &maxAge,
|
|
}
|
|
dbconfig := csconfig.DatabaseCfg{
|
|
Type: "sqlite",
|
|
DbPath: "./ent",
|
|
Flush: &flushConfig,
|
|
}
|
|
apiServerConfig := csconfig.LocalApiServerCfg{
|
|
ListenURI: "http://127.0.0.1:8080",
|
|
DbConfig: &dbconfig,
|
|
ProfilesPath: "./tests/profiles.yaml",
|
|
}
|
|
apiConfig := csconfig.APICfg{
|
|
Server: &apiServerConfig,
|
|
}
|
|
config.API = &apiConfig
|
|
if err := config.API.Server.LoadProfiles(); err != nil {
|
|
log.Fatalf("failed to load profiles: %s", err)
|
|
}
|
|
return config
|
|
}
|
|
|
|
func LoadTestConfigForwardedFor() csconfig.GlobalConfig {
|
|
config := csconfig.GlobalConfig{}
|
|
maxAge := "1h"
|
|
flushConfig := csconfig.FlushDBCfg{
|
|
MaxAge: &maxAge,
|
|
}
|
|
dbconfig := csconfig.DatabaseCfg{
|
|
Type: "sqlite",
|
|
DbPath: "./ent",
|
|
Flush: &flushConfig,
|
|
}
|
|
apiServerConfig := csconfig.LocalApiServerCfg{
|
|
ListenURI: "http://127.0.0.1:8080",
|
|
DbConfig: &dbconfig,
|
|
ProfilesPath: "./tests/profiles.yaml",
|
|
UseForwardedForHeaders: true,
|
|
}
|
|
apiConfig := csconfig.APICfg{
|
|
Server: &apiServerConfig,
|
|
}
|
|
config.API = &apiConfig
|
|
if err := config.API.Server.LoadProfiles(); err != nil {
|
|
log.Fatalf("failed to load profiles: %s", err)
|
|
}
|
|
return config
|
|
}
|
|
|
|
func NewAPITest() (*gin.Engine, error) {
|
|
config := LoadTestConfig()
|
|
|
|
os.Remove("./ent")
|
|
apiServer, err := NewServer(config.API.Server)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to run local API: %s", err)
|
|
}
|
|
log.Printf("Creating new API server")
|
|
gin.SetMode(gin.TestMode)
|
|
router, err := apiServer.Router()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to run local API: %s", err)
|
|
}
|
|
return router, nil
|
|
}
|
|
|
|
func NewAPITestForwardedFor() (*gin.Engine, error) {
|
|
config := LoadTestConfigForwardedFor()
|
|
|
|
os.Remove("./ent")
|
|
apiServer, err := NewServer(config.API.Server)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to run local API: %s", err)
|
|
}
|
|
log.Printf("Creating new API server")
|
|
gin.SetMode(gin.TestMode)
|
|
router, err := apiServer.Router()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to run local API: %s", err)
|
|
}
|
|
return router, nil
|
|
}
|
|
|
|
func ValidateMachine(machineID string) error {
|
|
config := LoadTestConfig()
|
|
dbClient, err := database.NewClient(config.API.Server.DbConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create new database client: %s", err)
|
|
}
|
|
if err := dbClient.ValidateMachine(machineID); err != nil {
|
|
return fmt.Errorf("unable to validate machine: %s", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetMachineIP(machineID string) (string, error) {
|
|
config := LoadTestConfig()
|
|
dbClient, err := database.NewClient(config.API.Server.DbConfig)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to create new database client: %s", err)
|
|
}
|
|
machines, err := dbClient.ListMachines()
|
|
if err != nil {
|
|
return "", fmt.Errorf("Unable to list machines: %s", err)
|
|
}
|
|
for _, machine := range machines {
|
|
if machine.MachineId == machineID {
|
|
return machine.IpAddress, nil
|
|
}
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func CreateTestMachine(router *gin.Engine) (string, error) {
|
|
b, err := json.Marshal(MachineTest)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to marshal MachineTest")
|
|
}
|
|
body := string(b)
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("POST", "/v1/watchers", strings.NewReader(body))
|
|
req.Header.Set("User-Agent", UserAgent)
|
|
router.ServeHTTP(w, req)
|
|
return body, nil
|
|
}
|
|
|
|
func CreateTestBouncer() (string, error) {
|
|
config := LoadTestConfig()
|
|
|
|
dbClient, err := database.NewClient(config.API.Server.DbConfig)
|
|
if err != nil {
|
|
log.Fatalf("unable to create new database client: %s", err)
|
|
}
|
|
apiKey, err := middlewares.GenerateAPIKey(keyLength)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to generate api key: %s", err)
|
|
}
|
|
err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey))
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to create blocker: %s", err)
|
|
}
|
|
|
|
return apiKey, nil
|
|
}
|
|
|
|
func TestWithWrongDBConfig(t *testing.T) {
|
|
config := LoadTestConfig()
|
|
config.API.Server.DbConfig.Type = "test"
|
|
apiServer, err := NewServer(config.API.Server)
|
|
|
|
assert.Equal(t, apiServer, &APIServer{})
|
|
assert.Equal(t, "unable to init database client: unknown database type", err.Error())
|
|
}
|
|
|
|
func TestWithWrongFlushConfig(t *testing.T) {
|
|
config := LoadTestConfig()
|
|
maxItems := -1
|
|
config.API.Server.DbConfig.Flush.MaxItems = &maxItems
|
|
apiServer, err := NewServer(config.API.Server)
|
|
|
|
assert.Equal(t, apiServer, &APIServer{})
|
|
assert.Equal(t, "max_items can't be zero or negative number", err.Error())
|
|
}
|
|
|
|
func TestUnknownPath(t *testing.T) {
|
|
router, err := NewAPITest()
|
|
if err != nil {
|
|
log.Fatalf("unable to run local API: %s", err)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("User-Agent", UserAgent)
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 404, w.Code)
|
|
|
|
}
|