diff --git a/.golangci.yml b/.golangci.yml index 13d8b4534..eba554d87 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,6 +9,13 @@ run: - pkg/yamlpatch/merge_test.go linters-settings: + gci: + sections: + - standard + - default + - prefix(github.com/crowdsecurity) + - prefix(github.com/crowdsecurity/crowdsec) + gocyclo: min-complexity: 30 diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index fcc9bd06a..31a947556 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -5,13 +5,14 @@ import ( "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -25,12 +26,11 @@ func TestAlertsListAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -39,19 +39,16 @@ func TestAlertsListAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { - if r.URL.RawQuery == "ip=1.2.3.4" { testMethod(t, r, "GET") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `null`) + return } @@ -107,36 +104,26 @@ func TestAlertsListAsMachine(t *testing.T) { ]`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" - ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := models.GetAlertsResponse{ &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, - Type: &ttype, + Simulated: ptr.Of(false), + Type: ptr.Of("ban"), Value: &tvalue, }, }, @@ -167,16 +154,16 @@ func TestAlertsListAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", Message: &tmessage, Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -188,8 +175,8 @@ func TestAlertsListAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), }, } @@ -198,30 +185,16 @@ func TestAlertsListAsMachine(t *testing.T) { //log.Debugf("expected : -> %s", spew.Sdump(expected)) //first one returns data alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, expected, *alerts) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } //this one doesn't - filter := AlertsListOpts{IPEquals: new(string)} - *filter.IPEquals = "1.2.3.4" + filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err = client.Alerts.List(context.Background(), filter) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, *alerts) } @@ -236,9 +209,7 @@ func TestAlertsGetAsMachine(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -247,12 +218,10 @@ func TestAlertsGetAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() + mux.HandleFunc("/alerts/2", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) @@ -312,34 +281,24 @@ func TestAlertsGetAsMachine(t *testing.T) { }`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" - tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, + Simulated: ptr.Of(false), Type: &ttype, Value: &tvalue, }, @@ -371,16 +330,16 @@ func TestAlertsGetAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", - Message: &tmessage, + Message: ptr.Of("Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"), Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -392,24 +351,18 @@ func TestAlertsGetAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), } alerts, resp, err := client.Alerts.GetByID(context.Background(), 1) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) //fail _, _, err = client.Alerts.GetByID(context.Background(), 2) - assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found") + cstest.RequireErrorMessage(t, err, "API error: object not found") } func TestAlertsCreateAsMachine(t *testing.T) { @@ -420,17 +373,17 @@ func TestAlertsCreateAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") w.WriteHeader(http.StatusOK) w.Write([]byte(`["3"]`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -439,10 +392,7 @@ func TestAlertsCreateAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() @@ -452,13 +402,8 @@ func TestAlertsCreateAsMachine(t *testing.T) { expected := &models.AddAlertsResponse{"3"} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } func TestAlertsDeleteAsMachine(t *testing.T) { @@ -469,18 +414,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -489,25 +434,16 @@ func TestAlertsDeleteAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() - alert := AlertsDeleteOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" + alert := AlertsDeleteOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err := client.Alerts.Delete(context.Background(), alert) require.NoError(t, err) expected := &models.DeleteAlertsResponse{NbDeleted: ""} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index b56d52868..f5de827a1 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/version" @@ -24,13 +25,11 @@ type BasicMockPayload struct { } func getLoginsForMockErrorCases() map[string]int { - loginsForMockErrorCases := map[string]int{ + return map[string]int{ "login_400": http.StatusBadRequest, "login_409": http.StatusConflict, "login_500": http.StatusInternalServerError, } - - return loginsForMockErrorCases } func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { @@ -49,7 +48,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { w.WriteHeader(http.StatusBadRequest) } - responseBody := "" + var responseBody string responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] if !hasFoundErrorMock { @@ -58,6 +57,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { } else { responseBody = fmt.Sprintf("Error %d", responseCode) } + log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode) w.WriteHeader(responseCode) @@ -76,14 +76,13 @@ func TestWatcherRegister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) // Valid Registration : should retrieve the client and no err clientconfig := Config{ @@ -95,9 +94,7 @@ func TestWatcherRegister(t *testing.T) { } client, err := RegisterClient(&clientconfig, &http.Client{}) - if client == nil || err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) @@ -107,11 +104,8 @@ func TestWatcherRegister(t *testing.T) { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) client, err = RegisterClient(&clientconfig, &http.Client{}) - if client != nil || err == nil { - t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) - } else { - log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest) - } + require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest) + require.Error(t, err, "error expected for the response code %d", errorCodeToTest) } } @@ -126,9 +120,7 @@ func TestWatcherAuth(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok auth clientConfig := &Config{ @@ -139,34 +131,27 @@ func TestWatcherAuth(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(clientConfig) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(clientConfig) + require.NoError(t, err) _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, Scenarios: clientConfig.Scenarios, }) - if err != nil { - t.Fatalf("unexpect auth err 0: %s", err) - } + require.NoError(t, err) // Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error // Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict} for _, errorCodeToTest := range errorCodesToTest { clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) + client, err := NewClient(clientConfig) + require.NoError(t, err) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - var resp *Response - _, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + _, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, }) @@ -175,9 +160,7 @@ func TestWatcherAuth(t *testing.T) { resp.Response.Body.Close() bodyBytes, err := io.ReadAll(resp.Response.Body) - if err != nil { - t.Fatalf("error while reading body: %s", err.Error()) - } + require.NoError(t, err) log.Printf(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) @@ -199,10 +182,12 @@ func TestWatcherUnregister(t *testing.T) { assert.Equal(t, int64(0), r.ContentLength) w.WriteHeader(http.StatusOK) }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) + newStr := buf.String() if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]} ` { @@ -217,9 +202,7 @@ func TestWatcherUnregister(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) mycfg := &Config{ MachineID: "test_login", @@ -229,16 +212,12 @@ func TestWatcherUnregister(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(mycfg) + require.NoError(t, err) _, err = client.Auth.UnregisterWatcher(context.Background()) - if err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) } @@ -255,6 +234,7 @@ func TestWatcherEnroll(t *testing.T) { _, _ = buf.ReadFrom(r.Body) newStr := buf.String() log.Debugf("body -> %s", newStr) + if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false} ` { log.Print("good key") @@ -266,17 +246,17 @@ func TestWatcherEnroll(t *testing.T) { fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`) } }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) mycfg := &Config{ MachineID: "test_login", @@ -286,16 +266,12 @@ func TestWatcherEnroll(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(mycfg) + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false) - if err != nil { - t.Fatalf("unexpect enroll err: %s", err) - } + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false) assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error()) diff --git a/pkg/apiclient/auth_test.go b/pkg/apiclient/auth_test.go index 7e7377a43..f686de622 100644 --- a/pkg/apiclient/auth_test.go +++ b/pkg/apiclient/auth_test.go @@ -9,6 +9,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestApiAuth(t *testing.T) { @@ -17,6 +20,7 @@ func TestApiAuth(t *testing.T) { mux, urlx, teardown := setup() mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.Header.Get("X-Api-Key") == "ixu" { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) @@ -26,11 +30,11 @@ func TestApiAuth(t *testing.T) { w.Write([]byte(`{"message":"access forbidden"}`)) } }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) defer teardown() @@ -40,18 +44,12 @@ func TestApiAuth(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - alert := DecisionsListOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" - _, resp, err := newcli.Decisions.List(context.Background(), alert) require.NoError(t, err) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + alert := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} + _, resp, err := newcli.Decisions.List(context.Background(), alert) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) //ko bad token auth = &APIKeyTransport{ @@ -59,25 +57,21 @@ func TestApiAuth(t *testing.T) { } newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, resp, err = newcli.Decisions.List(context.Background(), alert) log.Infof("--> %s", err) - if resp.Response.StatusCode != http.StatusForbidden { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusForbidden, resp.Response.StatusCode) + + cstest.RequireErrorMessage(t, err, "API error: access forbidden") - assert.Contains(t, err.Error(), "API error: access forbidden") //ko empty token auth = &APIKeyTransport{} + newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, _, err = newcli.Decisions.List(context.Background(), alert) require.Error(t, err) diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index fa25ee171..a7582eaf4 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -8,19 +8,19 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" ) func TestNewRequestInvalid(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + //missing slash in uri apiURL, err := url.Parse(urlx) - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -29,9 +29,8 @@ func TestNewRequestInvalid(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -44,17 +43,16 @@ func TestNewRequestInvalid(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `building request: BaseURL must have a trailing slash, but `) + cstest.RequireErrorContains(t, err, "building request: BaseURL must have a trailing slash, but ") } func TestNewRequestTimeout(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in uri + + // missing slash in uri apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -63,9 +61,8 @@ func TestNewRequestTimeout(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { time.Sleep(2 * time.Second) @@ -75,5 +72,5 @@ func TestNewRequestTimeout(t *testing.T) { defer cancel() _, _, err = client.Alerts.List(ctx, AlertsListOpts{}) - assert.Contains(t, err.Error(), `performing request: context deadline exceeded`) + cstest.RequireErrorMessage(t, err, "performing request: context deadline exceeded") } diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index a75b3dd41..dc6eae169 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -11,7 +11,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" ) @@ -20,13 +22,13 @@ import ( - each test will then bind handler for the method(s) they want to try */ -func setup() (mux *http.ServeMux, serverURL string, teardown func()) { +func setup() (*http.ServeMux, string, func()) { return setupWithPrefix("v1") } -func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { +func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) { // mux is the HTTP request multiplexer used with the test server. - mux = http.NewServeMux() + mux := http.NewServeMux() baseURLPath := "/" + urlPrefix apiHandler := http.NewServeMux() @@ -40,19 +42,16 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te func testMethod(t *testing.T, r *http.Request, want string) { t.Helper() - - if got := r.Method; got != want { - t.Errorf("Request method: %v, want %v", got, want) - } + assert.Equal(t, want, r.Method) } func TestNewClientOk(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -60,9 +59,8 @@ func TestNewClientOk(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -75,22 +73,17 @@ func TestNewClientOk(t *testing.T) { }) _, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) - if err != nil { - t.Fatalf("test Unable to list alerts : %+v", err) - } - - if resp.Response.StatusCode != http.StatusOK { - t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } func TestNewClientKo(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -98,9 +91,8 @@ func TestNewClientKo(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -113,36 +105,36 @@ func TestNewClientKo(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `API error: bad login/password`) + cstest.RequireErrorContains(t, err, `API error: bad login/password`) + log.Printf("err-> %s", err) } func TestNewDefaultClient(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewDefaultClient(apiURL, "/v1", "", nil) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"code": 401, "message" : "brr"}`)) }) + _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `performing request: API error: brr`) + cstest.RequireErrorMessage(t, err, "performing request: API error: brr") + log.Printf("err-> %s", err) } func TestNewClientRegisterKO(t *testing.T) { apiURL, err := url.Parse("http://127.0.0.1:4242/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + _, err = RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -150,17 +142,18 @@ func TestNewClientRegisterKO(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) + if runtime.GOOS != "windows" { - assert.Contains(t, fmt.Sprintf("%s", err), "dial tcp 127.0.0.1:4242: connect: connection refused") + cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused") } else { - assert.Contains(t, fmt.Sprintf("%s", err), " No connection could be made because the target machine actively refused it.") + cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.") } } func TestNewClientRegisterOK(t *testing.T) { log.SetLevel(log.TraceLevel) - mux, urlx, teardown := setup() + mux, urlx, teardown := setup() defer teardown() /*mock login*/ @@ -171,9 +164,8 @@ func TestNewClientRegisterOK(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -181,17 +173,15 @@ func TestNewClientRegisterOK(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - if err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) } func TestNewClientBadAnswer(t *testing.T) { log.SetLevel(log.TraceLevel) - mux, urlx, teardown := setup() + mux, urlx, teardown := setup() defer teardown() /*mock login*/ @@ -200,10 +190,10 @@ func TestNewClientBadAnswer(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`bad`)) }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + _, err = RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -211,5 +201,5 @@ func TestNewClientBadAnswer(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - assert.Contains(t, fmt.Sprintf("%s", err), `invalid body: invalid character 'b' looking for beginning of value`) + cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value") } diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index e9954d9a1..fb2fb7342 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -5,13 +5,13 @@ import ( "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" @@ -38,10 +38,9 @@ func TestDecisionsList(t *testing.T) { //no results } }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -49,55 +48,32 @@ func TestDecisionsList(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, } - //OK decisions - decisionsFilter := DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.4" + // OK decisions + decisionsFilter := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} decisions, resp, err := newcli.Decisions.List(context.Background(), decisionsFilter) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) //Empty return - decisionsFilter = DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.5" + decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")} decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, *decisions) } @@ -120,6 +96,7 @@ func TestDecisionsStream(t *testing.T) { } } }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) @@ -129,9 +106,7 @@ func TestDecisionsStream(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -139,63 +114,38 @@ func TestDecisionsStream(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, }, } decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) //and second call, we get empty lists decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, decisions.New) assert.Empty(t, decisions.Deleted) //delete stream resp, err = newcli.Decisions.StopStream(context.Background()) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } func TestDecisionsStreamV3Compatibility(t *testing.T) { @@ -219,9 +169,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -229,38 +177,30 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" torigin := "CAPI" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" ttype := "ban" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" - tscenarioDeleted := "deleted" - tdurationDeleted := "1h" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), Origin: &torigin, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, Type: &ttype, - Value: &tvalue, + Value: ptr.Of("1.2.3.4"), }, }, Deleted: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tdurationDeleted, + Duration: ptr.Of("1h"), Origin: &torigin, - Scenario: &tscenarioDeleted, + Scenario: ptr.Of("deleted"), Scope: &tscope, Type: &ttype, - Value: &tvalue1, + Value: ptr.Of("1.2.3.5"), }, }, } @@ -268,18 +208,8 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsStreamV3(t *testing.T) { @@ -300,9 +230,7 @@ func TestDecisionsStreamV3(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -310,30 +238,19 @@ func TestDecisionsStreamV3(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" - tdurationBlocklist := "24h" - tnameBlocklist := "blocklist1" - tremediationBlocklist := "ban" - tscopeBlocklist := "ip" - turlBlocklist := "/v3/blocklist" expected := &modelscapi.GetDecisionsStreamResponse{ New: modelscapi.GetDecisionsStreamResponseNew{ &modelscapi.GetDecisionsStreamResponseNewItem{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Duration: &tduration, - Value: &tvalue, + Duration: ptr.Of("3h59m55.756182786s"), + Value: ptr.Of("1.2.3.4"), }, }, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, }, }, @@ -341,18 +258,18 @@ func TestDecisionsStreamV3(t *testing.T) { &modelscapi.GetDecisionsStreamResponseDeletedItem{ Scope: &tscope, Decisions: []string{ - tvalue1, + "1.2.3.5", }, }, }, Links: &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{ { - Duration: &tdurationBlocklist, - Name: &tnameBlocklist, - Remediation: &tremediationBlocklist, - Scope: &tscopeBlocklist, - URL: &turlBlocklist, + Duration: ptr.Of("24h"), + Name: ptr.Of("blocklist1"), + Remediation: ptr.Of("ban"), + Scope: ptr.Of("ip"), + URL: ptr.Of("/v3/blocklist"), }, }, }, @@ -361,18 +278,8 @@ func TestDecisionsStreamV3(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStreamV3(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsFromBlocklist(t *testing.T) { @@ -383,10 +290,13 @@ func TestDecisionsFromBlocklist(t *testing.T) { mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, http.MethodGet) + if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" { w.WriteHeader(http.StatusNotModified) + return } + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte("1.2.3.4\r\n1.2.3.5")) @@ -394,9 +304,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -404,12 +312,8 @@ func TestDecisionsFromBlocklist(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tvalue1 := "1.2.3.4" - tvalue2 := "1.2.3.5" tdurationBlocklist := "24h" tnameBlocklist := "blocklist1" tremediationBlocklist := "ban" @@ -419,7 +323,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { expected := []*models.Decision{ { Duration: &tdurationBlocklist, - Value: &tvalue1, + Value: ptr.Of("1.2.3.4"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -427,7 +331,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { }, { Duration: &tdurationBlocklist, - Value: &tvalue2, + Value: ptr.Of("1.2.3.5"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -450,13 +354,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { log.Infof("expected : %s, %s, %s, %s, %s", *expected[0].Value, *expected[0].Duration, *expected[0].Scenario, *expected[0].Scope, *expected[0].Type) log.Infof("decisions: %s, %s, %s, %s, %s", *decisions[1].Value, *decisions[1].Duration, *decisions[1].Scenario, *decisions[1].Scope, *decisions[1].Type) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(decisions, expected) { - t.Fatalf("returned %+v, want %+v", decisions, expected) - } + assert.Equal(t, expected, decisions) // test cache control _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ @@ -466,8 +364,10 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.False(t, isModified) + _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ URL: &turlBlocklist, Scope: &tscopeBlocklist, @@ -475,6 +375,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.True(t, isModified) } @@ -485,6 +386,7 @@ func TestDeleteDecisions(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) @@ -492,11 +394,12 @@ func TestDeleteDecisions(t *testing.T) { w.Write([]byte(`{"nbDeleted":"1"}`)) //w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -504,18 +407,13 @@ func TestDeleteDecisions(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) filters := DecisionsDeleteOpts{IPEquals: new(string)} *filters.IPEquals = "1.2.3.4" - deleted, _, err := client.Decisions.Delete(context.Background(), filters) - if err != nil { - t.Fatalf("unexpected err : %s", err) - } + deleted, _, err := client.Decisions.Delete(context.Background(), filters) + require.NoError(t, err) assert.Equal(t, "1", deleted.NbDeleted) defer teardown() @@ -530,22 +428,23 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining string ScenariosNotContaining string } + tests := []struct { - name string - fields fields - want string - wantErr bool + name string + fields fields + expected string + expectedErr string }{ { - name: "no filter", - want: baseURLString + "?", + name: "no filter", + expected: baseURLString + "?", }, { name: "startup=true", fields: fields{ Startup: true, }, - want: baseURLString + "?startup=true", + expected: baseURLString + "?startup=true", }, { name: "set all params", @@ -555,7 +454,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: "ssh", ScenariosNotContaining: "bf", }, - want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", + expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, } @@ -568,25 +467,20 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: tt.fields.ScenariosContaining, ScenariosNotContaining: tt.fields.ScenariosNotContaining, } + got, err := o.addQueryParamsToURL(baseURLString) - if (err != nil) != tt.wantErr { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() error = %v, wantErr %v", err, tt.wantErr) + cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } gotURL, err := url.Parse(got) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + require.NoError(t, err) - expectedURL, err := url.Parse(tt.want) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + expectedURL, err := url.Parse(tt.expected) + require.NoError(t, err) - if *gotURL != *expectedURL { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() = %v, want %v", *gotURL, *expectedURL) - } + assert.Equal(t, *expectedURL, *gotURL) }) } } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 5824eb060..536505817 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -10,8 +10,8 @@ import ( "testing" "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -22,21 +22,14 @@ type LAPI struct { router *gin.Engine loginResp models.WatcherAuthResponse bouncerKey string - t *testing.T DBConfig *csconfig.DatabaseCfg } func SetupLAPITest(t *testing.T) LAPI { t.Helper() - router, loginResp, config, err := InitMachineTest(t) - if err != nil { - t.Fatal(err) - } + router, loginResp, config := InitMachineTest(t) - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - t.Fatal(err) - } + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) return LAPI{ router: router, @@ -46,24 +39,23 @@ func SetupLAPITest(t *testing.T) LAPI { } } -func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder { - alertReader := GetAlertReaderFromFile(path) - return l.RecordResponse(http.MethodPost, "/v1/alerts", alertReader, "password") +func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder { + alertReader := GetAlertReaderFromFile(t, path) + return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() req, err := http.NewRequest(verb, url, body) - if err != nil { - l.t.Fatal(err) - } + require.NoError(t, err) - if authType == "apikey" { + switch authType { + case "apikey": req.Header.Add("X-Api-Key", l.bouncerKey) - } else if authType == "password" { + case "password": AddAuthHeaders(req, l.loginResp) - } else { - l.t.Fatal("auth type not supported") + default: + t.Fatal("auth type not supported") } l.router.ServeHTTP(w, req) @@ -71,29 +63,16 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut return w } -func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config, error) { - router, config, err := NewAPITest(t) - if err != nil { - return nil, models.WatcherAuthResponse{}, config, fmt.Errorf("unable to run local API: %s", err) - } +func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { + router, config := NewAPITest(t) + loginResp := LoginToTestAPI(t, router, config) - loginResp, err := LoginToTestAPI(router, config) - if err != nil { - return nil, models.WatcherAuthResponse{}, config, err - } - - return router, loginResp, config, nil + return router, loginResp, config } -func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherAuthResponse, error) { - body, err := CreateTestMachine(router) - if err != nil { - return models.WatcherAuthResponse{}, err - } - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } +func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, router) + ValidateMachine(t, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) @@ -101,12 +80,10 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA router.ServeHTTP(w, req) loginResp := models.WatcherAuthResponse{} - err = json.NewDecoder(w.Body).Decode(&loginResp) - if err != nil { - return models.WatcherAuthResponse{}, err - } + err := json.NewDecoder(w.Body).Decode(&loginResp) + require.NoError(t, err) - return loginResp, nil + return loginResp } func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthResponse) { @@ -116,17 +93,17 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon func TestSimulatedAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_minibulk+simul.json") - alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json") //exclude decision in simulation mode - w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) //include decision in simulation mode - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) @@ -136,35 +113,29 @@ func TestCreateAlert(t *testing.T) { lapi := SetupLAPITest(t) // Create Alert with invalid format - w := lapi.RecordResponse(http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") + w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input - alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json") + alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") - w = lapi.RecordResponse(http.MethodPost, "/v1/alerts", alertContent, "password") + w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String()) + assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) // Create Valid Alert - w = lapi.InsertAlertFromFile("./tests/alert_sample.json") + w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assert.Equal(t, 201, w.Code) - assert.Equal(t, "[\"1\"]", w.Body.String()) + assert.Equal(t, `["1"]`, w.Body.String()) } func TestCreateAlertChannels(t *testing.T) { - apiServer, config, err := NewAPIServer(t) - if err != nil { - log.Fatalln(err) - } + apiServer, config := NewAPIServer(t) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() - loginResp, err := LoginToTestAPI(apiServer.router, config) - if err != nil { - log.Fatalln(err) - } + loginResp := LoginToTestAPI(t, apiServer.router, config) lapi := LAPI{router: apiServer.router, loginResp: loginResp} var ( @@ -180,7 +151,7 @@ func TestCreateAlertChannels(t *testing.T) { wg.Done() }() - go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") + go lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") wg.Wait() assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() @@ -188,18 +159,18 @@ func TestCreateAlertChannels(t *testing.T) { func TestAlertListFilters(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") - alertContent := GetAlertReaderFromFile("./tests/alert_ssh-bf.json") + lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json") //bad filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) //get without filters - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) //check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") @@ -207,149 +178,149 @@ func TestAlertListFilters(t *testing.T) { //test decision_type filter (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test decision_type filter (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scope (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scope (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scenario (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scenario (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test ip (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test range (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test range - w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test range (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test since (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test since (ok but yields no results) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test since (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test until (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test until (ok but no return) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test until (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test has active decision (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } @@ -357,32 +328,32 @@ func TestAlertListFilters(t *testing.T) { func TestAlertBulkInsert(t *testing.T) { lapi := SetupLAPITest(t) //insert a bulk of 20 alerts to trigger bulk insert - lapi.InsertAlertFromFile("./tests/alert_bulk.json") - alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") - w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } func TestListAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } func TestCreateAlertErrors(t *testing.T) { lapi := SetupLAPITest(t) - alertContent := GetAlertReaderFromFile("./tests/alert_sample.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") //test invalid bearer w := httptest.NewRecorder() @@ -403,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) { func TestDeleteAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() @@ -426,7 +397,7 @@ func TestDeleteAlert(t *testing.T) { func TestDeleteAlertByID(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() @@ -454,25 +425,18 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" server, err := NewServer(cfg.API.Server) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + err = server.InitController() - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + router, err := server.Router() - if err != nil { - log.Fatal(err) - } - loginResp, err := LoginToTestAPI(router, cfg) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + + loginResp := LoginToTestAPI(t, router, cfg) lapi := LAPI{ router: router, loginResp: loginResp, - t: t, } assertAlertDeleteFailedFromIP := func(ip string) { @@ -498,17 +462,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeleteFailedFromIP("4.3.2.1") assertAlertDeletedFromIP("1.2.3.4") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.0") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.1") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.255") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index df61e0b26..883ff2129 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -6,20 +6,14 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestAPIKey(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) + + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - log.Fatal(err) - } // Login with empty token w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) @@ -27,7 +21,7 @@ func TestAPIKey(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() @@ -37,7 +31,7 @@ func TestAPIKey(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 16dba1e86..74c627cd0 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -36,6 +36,7 @@ import ( func getDBClient(t *testing.T) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ @@ -72,8 +73,9 @@ func getAPIC(t *testing.T) *apic { } } -func absDiff(a int, b int) (c int) { - if c = a - b; c < 0 { +func absDiff(a int, b int) int { + c := a - b + if c < 0 { return -1 * c } @@ -185,6 +187,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { func TestNewAPIC(t *testing.T) { var testConfig *csconfig.OnlineApiClientCfg + setConfig := func() { testConfig = &csconfig.OnlineApiClientCfg{ Credentials: &csconfig.ApiCredentialsCfg{ @@ -199,6 +202,7 @@ func TestNewAPIC(t *testing.T) { dbClient *database.Client consoleConfig *csconfig.ConsoleConfig } + tests := []struct { name string args args @@ -374,7 +378,6 @@ func TestAPICGetMetrics(t *testing.T) { assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines) - }) } } @@ -403,6 +406,7 @@ func TestCreateAlertsForDecision(t *testing.T) { type args struct { decisions []*models.Decision } + tests := []struct { name string args args @@ -489,6 +493,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { alerts []*models.Alert decisions []*models.Decision } + tests := []struct { name string args args @@ -544,26 +549,18 @@ func TestAPICWhitelists(t *testing.T) { api := getAPIC(t) //one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} - ipwl1 := "9.2.3.4" - ip := net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - ipwl1 = "7.2.3.4" - ip = net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - cidrwl1 := "13.2.3.0/24" - _, tnet, err := net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) + + _, tnet, err := net.ParseCIDR("13.2.3.0/24") + require.NoError(t, err) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) - cidrwl1 = "11.2.3.0/24" - _, tnet, err = net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + + _, tnet, err = net.ParseCIDR("11.2.3.0/24") + require.NoError(t, err) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) + api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -663,12 +660,15 @@ func TestAPICWhitelists(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -801,12 +801,15 @@ func TestAPICPullTop(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -828,7 +831,8 @@ func TestAPICPullTop(t *testing.T) { alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) validDecisions := api.dbClient.Ent.Decision.Query().Where( decision.UntilGT(time.Now())). - AllX(context.Background()) + AllX(context.Background(), + ) decisionScenarioFreq := make(map[string]int) alertScenario := make(map[string]int) @@ -858,6 +862,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( modelscapi.GetDecisionsStreamResponse{ @@ -887,10 +892,12 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(200, "1.2.3.4"), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -916,6 +923,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { assert.NotEqual(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) + err = api.PullTop(false) require.NoError(t, err) secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) @@ -928,6 +936,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + // create a decision about to expire. It should force fetch alertInstance := api.dbClient.Ent.Alert. Create(). @@ -975,10 +984,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -1005,6 +1016,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(200, "1.2.3.4"), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -1073,6 +1085,7 @@ func TestAPICPush(t *testing.T) { Source: &models.Source{}, } } + return alerts }(), }, diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 62a8b83dd..b7f6be5fe 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -13,11 +13,12 @@ import ( "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" @@ -63,13 +64,14 @@ func LoadTestConfig(t *testing.T) csconfig.Config { ShareCustomScenarios: new(bool), }, } + apiConfig := csconfig.APICfg{ Server: &apiServerConfig, } + config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) return config } @@ -106,110 +108,89 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { Server: &apiServerConfig, } config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) { +func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") + apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) log.Printf("Creating new API server") gin.SetMode(gin.TestMode) - return apiServer, config, nil + return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) { - apiServer, config, err := NewAPIServer(t) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } - err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } +func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t) + + err := apiServer.InitController() + require.NoError(t, err) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) - return router, config, nil + return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) { +func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") + apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) + err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) log.Printf("Creating new API server") gin.SetMode(gin.TestMode) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) - return router, config, nil + return router, config } -func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error { +func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { dbClient, err := database.NewClient(config) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } + require.NoError(t, err) - if err := dbClient.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine: %s", err) - } - - return nil + err = dbClient.ValidateMachine(machineID) + require.NoError(t, err) } -func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) { +func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string { dbClient, err := database.NewClient(config) - if err != nil { - return "", fmt.Errorf("unable to create new database client: %s", err) - } + require.NoError(t, err) + machines, err := dbClient.ListMachines() - if err != nil { - return "", fmt.Errorf("Unable to list machines: %s", err) - } + require.NoError(t, err) for _, machine := range machines { if machine.MachineId == machineID { - return machine.IpAddress, nil + return machine.IpAddress } } - return "", nil + return "" } -func GetAlertReaderFromFile(path string) *strings.Reader { +func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) alerts := make([]*models.Alert, 0) - if err = json.Unmarshal(alertContentBytes, &alerts); err != nil { - log.Fatal(err) - } + err = json.Unmarshal(alertContentBytes, &alerts) + require.NoError(t, err) for _, alert := range alerts { *alert.StartAt = time.Now().UTC().Format(time.RFC3339) @@ -217,74 +198,57 @@ func GetAlertReaderFromFile(path string) *strings.Reader { } alertContent, err := json.Marshal(alerts) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) return strings.NewReader(string(alertContent)) } -func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) { +func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) { var response []*models.Decision - if resp == nil { - return nil, 0, errors.New("response is nil") - } - err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NotNil(t, resp) - return response, resp.Code, nil + err := json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) { +func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) { var response map[string]string - if resp == nil { - return nil, 0, errors.New("response is nil") - } - err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NotNil(t, resp) - return response, resp.Code, nil + err := json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) { +func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) { var response models.DeleteDecisionResponse - if resp == nil { - return nil, 0, errors.New("response is nil") - } + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NoError(t, err) - return &response, resp.Code, nil + return &response, resp.Code } -func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) { +func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) { response := make(map[string][]*models.Decision) - if resp == nil { - return nil, 0, errors.New("response is nil") - } + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NoError(t, err) - return response, resp.Code, nil + return response, resp.Code } -func CreateTestMachine(router *gin.Engine) (string, error) { +func CreateTestMachine(t *testing.T, router *gin.Engine) string { b, err := json.Marshal(MachineTest) - if err != nil { - return "", fmt.Errorf("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -292,26 +256,20 @@ func CreateTestMachine(router *gin.Engine) (string, error) { req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - return body, nil + return body } -func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) { +func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { dbClient, err := database.NewClient(config) - if err != nil { - log.Fatalf("unable to create new database client: %s", err) - } + require.NoError(t, err) apiKey, err := middlewares.GenerateAPIKey(keyLength) - if err != nil { - return "", fmt.Errorf("unable to generate api key: %s", err) - } + require.NoError(t, err) _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) - if err != nil { - return "", fmt.Errorf("unable to create blocker: %s", err) - } + require.NoError(t, err) - return apiKey, nil + return apiKey } func TestWithWrongDBConfig(t *testing.T) { @@ -334,10 +292,7 @@ func TestWithWrongFlushConfig(t *testing.T) { } func TestUnknownPath(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test", nil) @@ -384,24 +339,17 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.DebugLevel expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) expectedLines := []string{"/test42"} - cfg.LogLevel = &lvl + cfg.LogLevel = ptr.Of(log.DebugLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } - - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test42", nil) @@ -413,14 +361,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { //check file content data, err := os.ReadFile(expectedFile) - if err != nil { - t.Fatalf("failed to read file : %s", err) - } + require.NoError(t, err) for _, expectedStr := range expectedLines { - if !strings.Contains(string(data), expectedStr) { - t.Fatalf("expected %s in %s", expectedStr, string(data)) - } + assert.Contains(t, string(data), expectedStr) } } @@ -446,35 +390,29 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.ErrorLevel expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) - cfg.LogLevel = &lvl + cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } - api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + api, err := NewServer(&cfg) + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) //wait for the request to happen time.Sleep(500 * time.Millisecond) //check file content x, err := os.ReadFile(expectedFile) - if err == nil && len(x) > 0 { - t.Fatalf("file should be empty, got '%s'", x) + if err == nil { + require.Empty(t, x) } os.Remove("./crowdsec.log") diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 465accbac..e4c9dda47 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -16,23 +15,22 @@ func TestDeleteDecisionRange(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } @@ -41,23 +39,23 @@ func TestDeleteDecisionFilter(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } @@ -66,17 +64,17 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by wrong scenario - w := lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good - w = lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) } @@ -85,14 +83,13 @@ func TestGetDecisionFilters(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -104,10 +101,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -122,10 +118,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -137,10 +132,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -151,10 +145,9 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179") @@ -165,13 +158,12 @@ func TestGetDecision(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 3) /*decisions get doesn't perform deduplication*/ @@ -188,7 +180,7 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) assert.Len(t, decisions, 3) } @@ -197,49 +189,43 @@ func TestDeleteDecisionByID(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") //Have one alerts - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - require.NoError(t, err) + w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) - errResp, _, err := readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist - w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - errResp, _, err = readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ = readDecisionsErrorResp(t, w) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) //Have one alerts - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with valid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - require.NoError(t, err) + resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "1", resp.NbDeleted) //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -249,20 +235,18 @@ func TestDeleteDecision(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - errResp, _, err := readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert - w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - require.NoError(t, err) + resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "3", resp.NbDeleted) } @@ -271,12 +255,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - require.NoError(t, err) + w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -285,13 +268,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -300,13 +282,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -315,13 +296,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) //and now we only get a deleted decision - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions["deleted"], 1) assert.Equal(t, int64(1), decisions["deleted"][0].ID) diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index 0082f23ec..fbf01c7fb 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -10,9 +10,9 @@ import ( func TestHeartBeat(t *testing.T) { lapi := SetupLAPITest(t) - w := lapi.RecordResponse(http.MethodGet, "/v1/heartbeat", emptyBody, "password") + w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password") + w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 886962250..58f66cfc7 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -6,20 +6,13 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestLogin(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, router) // Login with machine not validated yet w := httptest.NewRecorder() @@ -28,16 +21,16 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"machine test not validated\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String()) // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\", \"password\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"ent: machine not found\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) // Login with invalid body w = httptest.NewRecorder() @@ -46,31 +39,28 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"missing: invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) //Validate machine - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } + ValidateMachine(t, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"incorrect Username or Password\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) // Login with valid machine w = httptest.NewRecorder() @@ -79,16 +69,16 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test\", \"scenarios\": [\"crowdsecurity/test\", \"crowdsecurity/test2\"]}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) } diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 6ac016404..08efa91c6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -7,15 +7,12 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCreateMachine(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) // Create machine with invalid format w := httptest.NewRecorder() @@ -24,22 +21,21 @@ func TestCreateMachine(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("{\"test\": \"test\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\nmachine_id in body is required\\npassword in body is required\"}", w.Body.String()) + assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w = httptest.NewRecorder() @@ -52,16 +48,12 @@ func TestCreateMachine(t *testing.T) { } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITestForwardedFor(t) router.TrustedPlatform = "X-Real-IP" // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -73,25 +65,18 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) assert.Equal(t, "1.1.1.1", ip) } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -103,26 +88,20 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + //For some reason, the IP is empty when running tests //if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITestForwardedFor(t) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -133,25 +112,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + //For some reason, the IP is empty when running tests //if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, router) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) @@ -164,5 +135,5 @@ func TestCreateMachineAlreadyExist(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String()) + assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) }