From 89f704ef184a9b59c7e7aad158928ef44fdc556e Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 14 Dec 2023 14:54:11 +0100 Subject: [PATCH] light pkg/api{client,server} refact (#2659) * tests: don't run crowdsec if not necessary * make listen_uri report the random port number when 0 is requested * move apiserver.getTLSAuthType() -> csconfig.TLSCfg.GetAuthType() * move apiserver.isEnrolled() -> apiclient.ApiClient.IsEnrolled() * extract function apiserver.recoverFromPanic() * simplify and move APIServer.GetTLSConfig() -> TLSCfg.GetTLSConfig() * moved TLSCfg type to csconfig/tls.go * APIServer.InitController(): early return / happy path * extract function apiserver.newGinLogger() * lapi tests * update unit test * lint (testify) * lint (whitespace, variable names) * update docker tests --- docker/test/tests/test_agent.py | 10 +- docker/test/tests/test_agent_only.py | 2 +- docker/test/tests/test_local_api_url.py | 6 +- docker/test/tests/test_tls.py | 12 +- pkg/apiclient/alerts_service.go | 30 +- pkg/apiclient/alerts_service_test.go | 25 +- pkg/apiclient/auth.go | 35 +- pkg/apiclient/auth_service.go | 7 + pkg/apiclient/auth_service_test.go | 21 +- pkg/apiclient/auth_test.go | 5 +- pkg/apiclient/client.go | 33 +- pkg/apiclient/client_http.go | 12 +- pkg/apiclient/client_http_test.go | 2 + pkg/apiclient/client_test.go | 6 + pkg/apiclient/decisions_service.go | 35 +- pkg/apiclient/decisions_service_test.go | 32 +- pkg/apiclient/decisions_sync_service.go | 4 + pkg/apiclient/heartbeat.go | 1 - pkg/apiclient/metrics.go | 2 + pkg/apiclient/signal.go | 3 + pkg/apiserver/alerts_test.go | 26 +- pkg/apiserver/api_key_test.go | 1 - pkg/apiserver/apic.go | 176 +++++++-- pkg/apiserver/apic_metrics.go | 13 +- pkg/apiserver/apic_metrics_test.go | 1 + pkg/apiserver/apic_test.go | 45 ++- pkg/apiserver/apiserver.go | 451 +++++++++++----------- pkg/apiserver/apiserver_test.go | 55 ++- pkg/apiserver/controllers/controller.go | 5 +- pkg/apiserver/controllers/v1/alerts.go | 6 +- pkg/apiserver/controllers/v1/decisions.go | 5 +- pkg/apiserver/controllers/v1/errors.go | 3 +- pkg/apiserver/controllers/v1/machines.go | 5 +- pkg/apiserver/controllers/v1/metrics.go | 7 +- pkg/apiserver/controllers/v1/utils.go | 3 +- pkg/apiserver/decisions_test.go | 96 ++--- pkg/apiserver/jwt_test.go | 1 - pkg/apiserver/machines_test.go | 3 +- pkg/apiserver/middlewares/v1/api_key.go | 7 +- pkg/apiserver/middlewares/v1/jwt.go | 20 +- pkg/apiserver/utils.go | 27 -- pkg/csconfig/api.go | 17 +- pkg/csconfig/api_test.go | 4 +- pkg/csconfig/tls.go | 87 +++++ test/bats/01_crowdsec_lapi.bats | 51 +++ test/bats/01_cscli.bats | 6 +- 46 files changed, 927 insertions(+), 477 deletions(-) delete mode 100644 pkg/apiserver/utils.go create mode 100644 pkg/csconfig/tls.go create mode 100644 test/bats/01_crowdsec_lapi.bats diff --git a/docker/test/tests/test_agent.py b/docker/test/tests/test_agent.py index e1ede3f89..e55d11af8 100644 --- a/docker/test/tests/test_agent.py +++ b/docker/test/tests/test_agent.py @@ -13,7 +13,7 @@ def test_no_agent(crowdsec, flavor): 'DISABLE_AGENT': 'true', } with crowdsec(flavor=flavor, environment=env) as cs: - cs.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") + cs.wait_for_log("*CrowdSec Local API listening on *:8080*") cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') assert res.exit_code == 0 @@ -37,7 +37,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -50,7 +50,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -65,7 +65,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Generate local agent credentials*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') @@ -78,7 +78,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory): with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: cs.wait_for_log([ "*Local agent already registered*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') diff --git a/docker/test/tests/test_agent_only.py b/docker/test/tests/test_agent_only.py index d9db3ca30..038b726e3 100644 --- a/docker/test/tests/test_agent_only.py +++ b/docker/test/tests/test_agent_only.py @@ -29,7 +29,7 @@ def test_split_lapi_agent(crowdsec, flavor): cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor) with cs_lapi as lapi: - lapi.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") + lapi.wait_for_log("*CrowdSec Local API listening on *:8080*") lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) with cs_agent as agent: agent.wait_for_log("*Starting processing data*") diff --git a/docker/test/tests/test_local_api_url.py b/docker/test/tests/test_local_api_url.py index 262e8fbef..aa90c9fb7 100644 --- a/docker/test/tests/test_local_api_url.py +++ b/docker/test/tests/test_local_api_url.py @@ -11,7 +11,7 @@ def test_local_api_url_default(crowdsec, flavor): """Test LOCAL_API_URL (default)""" with crowdsec(flavor=flavor) as cs: cs.wait_for_log([ - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", "*Starting processing data*" ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) @@ -29,7 +29,7 @@ def test_local_api_url(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env) as cs: cs.wait_for_log([ - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on *:8080*", "*Starting processing data*" ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) @@ -54,7 +54,7 @@ def test_local_api_url_ipv6(crowdsec, flavor): with crowdsec(flavor=flavor, environment=env) as cs: cs.wait_for_log([ "*Starting processing data*", - "*CrowdSec Local API listening on 0.0.0.0:8080*", + "*CrowdSec Local API listening on [::1]:8080*", ]) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) res = cs.cont.exec_run('cscli lapi status') diff --git a/docker/test/tests/test_tls.py b/docker/test/tests/test_tls.py index eeda18f56..f12b2ff1b 100644 --- a/docker/test/tests/test_tls.py +++ b/docker/test/tests/test_tls.py @@ -23,7 +23,7 @@ def test_missing_key_file(crowdsec, flavor): with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: # XXX: this message appears twice, is that normal? - cs.wait_for_log("*while serving local API: missing TLS key file*") + cs.wait_for_log("*while starting API server: missing TLS key file*") def test_missing_cert_file(crowdsec, flavor): @@ -35,7 +35,7 @@ def test_missing_cert_file(crowdsec, flavor): } with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: - cs.wait_for_log("*while serving local API: missing TLS cert file*") + cs.wait_for_log("*while starting API server: missing TLS cert file*") def test_tls_missing_ca(crowdsec, flavor, certs_dir): @@ -174,7 +174,7 @@ def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) @@ -225,7 +225,7 @@ def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) @@ -276,7 +276,7 @@ def test_tls_client_ou(crowdsec, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) @@ -306,7 +306,7 @@ def test_tls_client_ou(crowdsec, certs_dir): with cs_lapi as lapi: lapi.wait_for_log([ "*(tls) Client Auth Type set to VerifyClientCertIfGiven*", - "*CrowdSec Local API listening on 0.0.0.0:8080*" + "*CrowdSec Local API listening on *:8080*" ]) # TODO: wait_for_https lapi.wait_for_http(8080, '/health', want_status=None) diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index dd2ba2975..1d0a4ebd1 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -49,31 +49,37 @@ type AlertsDeleteOpts struct { } func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) { - - var added_ids models.AddAlertsResponse + var addedIds models.AddAlertsResponse u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &alerts) + if err != nil { return nil, nil, err } - resp, err := s.client.Do(ctx, req, &added_ids) + resp, err := s.client.Do(ctx, req, &addedIds) if err != nil { return nil, resp, err } - return &added_ids, resp, nil + + return &addedIds, resp, nil } // to demo query arguments func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) { - var alerts models.GetAlertsResponse - var URI string + var ( + alerts models.GetAlertsResponse + URI string + ) + u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) params, err := qs.Values(opts) + if err != nil { return nil, nil, fmt.Errorf("building query: %w", err) } + if len(params) > 0 { URI = fmt.Sprintf("%s?%s", u, params.Encode()) } else { @@ -89,16 +95,19 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. if err != nil { return nil, resp, fmt.Errorf("performing request: %w", err) } + return &alerts, resp, nil } // to demo query arguments func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) { var alerts models.DeleteAlertsResponse + params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -110,12 +119,14 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod if err != nil { return nil, resp, err } + return &alerts, resp, nil } -func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models.DeleteAlertsResponse, *Response, error) { +func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) { var alerts models.DeleteAlertsResponse - u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alert_id) + + u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { @@ -126,11 +137,13 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models if err != nil { return nil, resp, err } + return &alerts, resp, nil } func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) { var alert models.Alert + u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -142,5 +155,6 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert if err != nil { return nil, nil, err } + return &alert, resp, nil } diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index 49313f007..fcc9bd06a 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -26,10 +26,12 @@ func TestAlertsListAsMachine(t *testing.T) { w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { log.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -199,6 +201,7 @@ func TestAlertsListAsMachine(t *testing.T) { if err != nil { log.Errorf("test Unable to list alerts : %+v", err) } + if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } @@ -209,14 +212,17 @@ func TestAlertsListAsMachine(t *testing.T) { //this one doesn't filter := AlertsListOpts{IPEquals: new(string)} *filter.IPEquals = "1.2.3.4" + alerts, resp, err = client.Alerts.List(context.Background(), filter) if err != nil { log.Errorf("test Unable to list alerts : %+v", err) } + if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } - assert.Equal(t, 0, len(*alerts)) + + assert.Empty(t, *alerts) } func TestAlertsGetAsMachine(t *testing.T) { @@ -228,10 +234,12 @@ func TestAlertsGetAsMachine(t *testing.T) { w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { log.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -390,6 +398,7 @@ func TestAlertsGetAsMachine(t *testing.T) { alerts, resp, err := client.Alerts.GetByID(context.Background(), 1) require.NoError(t, err) + if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } @@ -401,7 +410,6 @@ func TestAlertsGetAsMachine(t *testing.T) { //fail _, _, err = client.Alerts.GetByID(context.Background(), 2) assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found") - } func TestAlertsCreateAsMachine(t *testing.T) { @@ -418,10 +426,12 @@ func TestAlertsCreateAsMachine(t *testing.T) { w.Write([]byte(`["3"]`)) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { log.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -435,13 +445,17 @@ func TestAlertsCreateAsMachine(t *testing.T) { } defer teardown() + alert := models.AddAlertsRequest{} alerts, resp, err := client.Alerts.Add(context.Background(), alert) require.NoError(t, err) + expected := &models.AddAlertsResponse{"3"} + if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } + if !reflect.DeepEqual(*alerts, *expected) { t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) } @@ -457,15 +471,17 @@ func TestAlertsDeleteAsMachine(t *testing.T) { }) mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { log.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -479,15 +495,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) { } defer teardown() + alert := AlertsDeleteOpts{IPEquals: new(string)} *alert.IPEquals = "1.2.3.4" alerts, resp, err := client.Alerts.Delete(context.Background(), alert) require.NoError(t, err) expected := &models.DeleteAlertsResponse{NbDeleted: ""} + if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } + if !reflect.DeepEqual(*alerts, *expected) { t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) } diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 96230b910..86cdc7736 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -41,10 +41,13 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { // specification of http.RoundTripper. req = cloneRequest(req) req.Header.Add("X-Api-Key", t.APIKey) + if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) } + log.Debugf("req-api: %s %s", req.Method, req.URL.String()) + if log.GetLevel() >= log.TraceLevel { dump, _ := httputil.DumpRequest(req, true) log.Tracef("auth-api request: %s", string(dump)) @@ -55,6 +58,7 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err) return resp, err } + if log.GetLevel() >= log.TraceLevel { dump, _ := httputil.DumpResponse(resp, true) log.Tracef("auth-api response: %s", string(dump)) @@ -73,6 +77,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper { if t.Transport != nil { return t.Transport } + return http.DefaultTransport } @@ -90,15 +95,19 @@ func (r retryRoundTripper) ShouldRetry(statusCode int) bool { return true } } + return false } func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var resp *http.Response - var err error + var ( + resp *http.Response + err error + ) backoff := 0 maxAttempts := r.maxAttempts + if fflag.DisableHttpRetryBackoff.IsEnabled() { maxAttempts = 1 } @@ -108,6 +117,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if r.withBackOff { backoff += 10 + rand.Intn(20) } + log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) select { case <-req.Context().Done(): @@ -115,22 +125,28 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) case <-time.After(time.Duration(backoff) * time.Second): } } + if r.onBeforeRequest != nil { r.onBeforeRequest(i) } + clonedReq := cloneRequest(req) resp, err = r.next.RoundTrip(clonedReq) + if err != nil { left := maxAttempts - i - 1 if left > 0 { log.Errorf("error while performing request: %s; %d retries left", err, left) } + continue } + if !r.ShouldRetry(resp.StatusCode) { return resp, nil } } + return resp, err } @@ -157,6 +173,7 @@ func (t *JWTTransport) refreshJwtToken() error { if err != nil { return fmt.Errorf("can't update scenario list: %s", err) } + log.Debugf("scenarios list updated for '%s'", *t.MachineID) } @@ -175,14 +192,18 @@ func (t *JWTTransport) refreshJwtToken() error { enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) err = enc.Encode(auth) + if err != nil { return fmt.Errorf("could not encode jwt auth body: %w", err) } + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf) if err != nil { return fmt.Errorf("could not create request: %w", err) } + req.Header.Add("Content-Type", "application/json") + client := &http.Client{ Transport: &retryRoundTripper{ next: http.DefaultTransport, @@ -191,9 +212,11 @@ func (t *JWTTransport) refreshJwtToken() error { retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, }, } + if t.UserAgent != "" { req.Header.Add("User-Agent", t.UserAgent) } + if log.GetLevel() >= log.TraceLevel { dump, _ := httputil.DumpRequest(req, true) log.Tracef("auth-jwt request: %s", string(dump)) @@ -205,6 +228,7 @@ func (t *JWTTransport) refreshJwtToken() error { if err != nil { return fmt.Errorf("could not get jwt token: %w", err) } + log.Debugf("auth-jwt : http %d", resp.StatusCode) if log.GetLevel() >= log.TraceLevel { @@ -226,12 +250,15 @@ func (t *JWTTransport) refreshJwtToken() error { if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return fmt.Errorf("unable to decode response: %w", err) } + if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { return fmt.Errorf("unable to parse jwt expiration: %w", err) } + t.Token = response.Token log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) + return nil } @@ -267,6 +294,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { dump, _ := httputil.DumpResponse(resp, true) log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) } + if err != nil { /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ t.Token = "" @@ -333,9 +361,12 @@ func cloneRequest(r *http.Request) *http.Request { if r.Body != nil { var b bytes.Buffer + b.ReadFrom(r.Body) + r.Body = io.NopCloser(&b) r2.Body = io.NopCloser(bytes.NewReader(b.Bytes())) } + return r2 } diff --git a/pkg/apiclient/auth_service.go b/pkg/apiclient/auth_service.go index 64284902e..e43503852 100644 --- a/pkg/apiclient/auth_service.go +++ b/pkg/apiclient/auth_service.go @@ -22,6 +22,7 @@ type enrollRequest struct { func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) { u := fmt.Sprintf("%s/watchers", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, err @@ -31,6 +32,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) if err != nil { return resp, err } + return resp, nil } @@ -46,6 +48,7 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W if err != nil { return resp, err } + return resp, nil } @@ -53,6 +56,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch var authResp models.WatcherAuthResponse u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &auth) if err != nil { return authResp, nil, err @@ -62,11 +66,13 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch if err != nil { return authResp, resp, err } + return authResp, resp, nil } func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) { u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite}) if err != nil { return nil, err @@ -76,5 +82,6 @@ func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name if err != nil { return resp, err } + return resp, nil } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index 7c89d24f1..b56d52868 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -35,6 +35,7 @@ func getLoginsForMockErrorCases() map[string]int { func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { loginsForMockErrorCases := getLoginsForMockErrorCases() + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") buf := new(bytes.Buffer) @@ -71,7 +72,6 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { * 400, 409, 500 => Error */ func TestWatcherRegister(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() @@ -79,6 +79,7 @@ func TestWatcherRegister(t *testing.T) { //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { t.Fatalf("parsing api url: %s", apiURL) @@ -92,16 +93,19 @@ func TestWatcherRegister(t *testing.T) { URL: apiURL, VersionPrefix: "v1", } + client, err := RegisterClient(&clientconfig, &http.Client{}) if client == nil || err != nil { t.Fatalf("while registering client : %s", err) } + log.Printf("->%T", client) // Testing error handling on Registration (400, 409, 500): should retrieve an error errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError} for _, errorCodeToTest := range errorCodesToTest { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) + client, err = RegisterClient(&clientconfig, &http.Client{}) if client != nil || err == nil { t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) @@ -112,7 +116,6 @@ func TestWatcherRegister(t *testing.T) { } func TestWatcherAuth(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() @@ -121,6 +124,7 @@ func TestWatcherAuth(t *testing.T) { initBasicMuxMock(t, mux, "/watchers/login") log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { t.Fatalf("parsing api url: %s", apiURL) @@ -169,6 +173,7 @@ func TestWatcherAuth(t *testing.T) { if err == nil { resp.Response.Body.Close() + bodyBytes, err := io.ReadAll(resp.Response.Body) if err != nil { t.Fatalf("error while reading body: %s", err.Error()) @@ -176,14 +181,13 @@ func TestWatcherAuth(t *testing.T) { log.Printf(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) - } else { - log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest) } + + log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest) } } func TestWatcherUnregister(t *testing.T) { - log.SetLevel(log.DebugLevel) mux, urlx, teardown := setup() @@ -192,7 +196,7 @@ func TestWatcherUnregister(t *testing.T) { mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.ContentLength, int64(0)) + assert.Equal(t, int64(0), r.ContentLength) w.WriteHeader(http.StatusOK) }) mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { @@ -211,10 +215,12 @@ func TestWatcherUnregister(t *testing.T) { }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { t.Fatalf("parsing api url: %s", apiURL) } + mycfg := &Config{ MachineID: "test_login", Password: "test_password", @@ -228,10 +234,12 @@ func TestWatcherUnregister(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + _, err = client.Auth.UnregisterWatcher(context.Background()) if err != nil { t.Fatalf("while registering client : %s", err) } + log.Printf("->%T", client) } @@ -264,6 +272,7 @@ func TestWatcherEnroll(t *testing.T) { fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) }) log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") if err != nil { t.Fatalf("parsing api url: %s", apiURL) diff --git a/pkg/apiclient/auth_test.go b/pkg/apiclient/auth_test.go index f28a0ea05..7e7377a43 100644 --- a/pkg/apiclient/auth_test.go +++ b/pkg/apiclient/auth_test.go @@ -18,7 +18,7 @@ func TestApiAuth(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") if r.Header.Get("X-Api-Key") == "ixu" { - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`null`)) } else { @@ -66,9 +66,11 @@ func TestApiAuth(t *testing.T) { _, resp, err = newcli.Decisions.List(context.Background(), alert) log.Infof("--> %s", err) + if resp.Response.StatusCode != http.StatusForbidden { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } + assert.Contains(t, err.Error(), "API error: access forbidden") //ko empty token auth = &APIKeyTransport{} @@ -82,5 +84,4 @@ func TestApiAuth(t *testing.T) { log.Infof("--> %s", err) assert.Contains(t, err.Error(), "APIKey is empty") - } diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index d95f77490..75bc52881 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -10,6 +10,8 @@ import ( "net/http" "net/url" + "github.com/golang-jwt/jwt/v4" + "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -43,6 +45,21 @@ func (a *ApiClient) GetClient() *http.Client { return a.client } +func (a *ApiClient) IsEnrolled() bool { + jwtTransport := a.client.Transport.(*JWTTransport) + tokenStr := jwtTransport.Token + + token, _ := jwt.Parse(tokenStr, nil) + if token == nil { + return false + } + + claims := token.Claims.(jwt.MapClaims) + _, ok := claims["organization_id"] + + return ok +} + type service struct { client *ApiClient } @@ -59,12 +76,15 @@ func NewClient(config *Config) (*ApiClient, error) { } tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool + if Cert != nil { tlsconfig.Certificates = []tls.Certificate{*Cert} } + if ht, ok := http.DefaultTransport.(*http.Transport); ok { ht.TLSClientConfig = &tlsconfig } + c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) @@ -81,16 +101,20 @@ func NewClient(config *Config) (*ApiClient, error) { func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) { if client == nil { client = &http.Client{} + if ht, ok := http.DefaultTransport.(*http.Transport); ok { tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool + if Cert != nil { tlsconfig.Certificates = []tls.Certificate{*Cert} } + ht.TLSClientConfig = &tlsconfig client.Transport = ht } } + c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) @@ -108,11 +132,13 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { if client == nil { client = &http.Client{} } + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} if Cert != nil { tlsconfig.RootCAs = CaCertPool tlsconfig.Certificates = []tls.Certificate{*Cert} } + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c.common.client = c @@ -126,10 +152,11 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { if resp != nil && resp.Response != nil { return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err) } + return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err) } - return c, nil + return c, nil } type Response struct { @@ -148,6 +175,7 @@ func (e *ErrorResponse) Error() string { if len(e.Errors) > 0 { err += fmt.Sprintf(" (%s)", e.Errors) } + return err } @@ -160,7 +188,9 @@ func CheckResponse(r *http.Response) error { if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { return nil } + errorResponse := &ErrorResponse{} + data, err := io.ReadAll(r.Body) if err == nil && data != nil { err := json.Unmarshal(data, errorResponse) @@ -171,6 +201,7 @@ func CheckResponse(r *http.Response) error { errorResponse.Message = new(string) *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) } + return errorResponse } diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index 2c55128e1..5222ad770 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -19,6 +19,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ if !strings.HasSuffix(c.BaseURL.Path, "/") { return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) } + u, err := c.BaseURL.Parse(url) if err != nil { return nil, err @@ -29,8 +30,8 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ buf = &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) - err := enc.Encode(body) - if err != nil { + + if err = enc.Encode(body); err != nil { return nil, err } } @@ -51,6 +52,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if ctx == nil { return nil, errors.New("context must be non-nil") } + req = req.WithContext(ctx) // Check rate limit @@ -62,6 +64,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if log.GetLevel() >= log.DebugLevel { log.Debugf("[URL] %s %s", req.Method, req.URL) } + resp, err := c.client.Do(req) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -82,8 +85,10 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* e.URL = url.String() return newResponse(resp), e } + return newResponse(resp), err } + return newResponse(resp), err } @@ -112,9 +117,12 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if errors.Is(decErr, io.EOF) { decErr = nil // ignore EOF errors caused by empty response body } + return response, decErr } + io.Copy(w, resp.Body) } + return response, err } diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index 7b2075d9d..fa25ee171 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -21,6 +21,7 @@ func TestNewRequestInvalid(t *testing.T) { if err != nil { t.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -54,6 +55,7 @@ func TestNewRequestTimeout(t *testing.T) { if err != nil { t.Fatalf("parsing api url: %s", apiURL) } + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index f2ac84fbe..a75b3dd41 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -40,6 +40,7 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te func testMethod(t *testing.T, r *http.Request, want string) { t.Helper() + if got := r.Method; got != want { t.Errorf("Request method: %v, want %v", got, want) } @@ -77,6 +78,7 @@ func TestNewClientOk(t *testing.T) { if err != nil { t.Fatalf("test Unable to list alerts : %+v", err) } + if resp.Response.StatusCode != http.StatusOK { t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) } @@ -126,6 +128,7 @@ func TestNewDefaultClient(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"code": 401, "message" : "brr"}`)) @@ -157,6 +160,7 @@ func TestNewClientRegisterKO(t *testing.T) { func TestNewClientRegisterOK(t *testing.T) { log.SetLevel(log.TraceLevel) mux, urlx, teardown := setup() + defer teardown() /*mock login*/ @@ -180,12 +184,14 @@ func TestNewClientRegisterOK(t *testing.T) { if err != nil { t.Fatalf("while registering client : %s", err) } + log.Printf("->%T", client) } func TestNewClientBadAnswer(t *testing.T) { log.SetLevel(log.TraceLevel) mux, urlx, teardown := setup() + defer teardown() /*mock login*/ diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index b9475971a..89e6eff92 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -42,6 +42,7 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) { if err != nil { return "", err } + return fmt.Sprintf("%s?%s", url, params.Encode()), nil } @@ -61,10 +62,12 @@ type DecisionsDeleteOpts struct { // to demo query arguments func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) { var decisions models.GetDecisionsResponse + params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -111,14 +114,18 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi. Origin: ptr.Of(types.CAPIOrigin), } } + decisions = append(decisions, partialDecisions...) } + return decisions } func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var decisions modelscapi.GetDecisionsStreamResponse - var v2Decisions models.DecisionsStreamResponse + var ( + decisions modelscapi.GetDecisionsStreamResponse + v2Decisions models.DecisionsStreamResponse + ) scenarioDeleted := "deleted" durationDeleted := "1h" @@ -134,8 +141,10 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m } v2Decisions.New = s.GetDecisionsFromGroups(decisions.New) + for _, decisionsGroup := range decisions.Deleted { partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) + for idx, decision := range decisionsGroup.Decisions { decision := decision // fix exportloopref linter message partialDecisions[idx] = &models.Decision{ @@ -147,6 +156,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m Origin: ptr.Of(types.CAPIOrigin), } } + v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...) } @@ -161,6 +171,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl log.Debugf("Fetching blocklist %s", *blocklist.URL) client := http.Client{} + req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil) if err != nil { return nil, false, err @@ -169,6 +180,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl if lastPullTimestamp != nil { req.Header.Set("If-Modified-Since", *lastPullTimestamp) } + req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc @@ -188,6 +200,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl // If the error type is *url.Error, sanitize its URL before returning. log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err) + return nil, false, err } @@ -197,13 +210,17 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl } else { log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL) } + return nil, false, nil } + if resp.StatusCode != http.StatusOK { log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) return nil, false, nil } + decisions := make([]*models.Decision, 0) + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { decision := scanner.Text() @@ -227,6 +244,7 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp if err != nil { return nil, nil, err } + if s.client.URLPrefix == "v3" { return s.FetchV3Decisions(ctx, u) } else { @@ -239,6 +257,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream if err != nil { return nil, nil, err } + var decisions modelscapi.GetDecisionsStreamResponse req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -255,8 +274,8 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream } func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { - u := fmt.Sprintf("%s/decisions", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { return nil, err @@ -266,15 +285,18 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { if err != nil { return resp, err } + return resp, nil } func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) { var deleteDecisionResponse models.DeleteDecisionResponse + params, err := qs.Values(opts) if err != nil { return nil, nil, err } + u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -286,12 +308,14 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) if err != nil { return nil, resp, err } + return &deleteDecisionResponse, resp, nil } -func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*models.DeleteDecisionResponse, *Response, error) { +func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) { var deleteDecisionResponse models.DeleteDecisionResponse - u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decision_id) + + u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) if err != nil { @@ -302,5 +326,6 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (* if err != nil { return nil, resp, err } + return &deleteDecisionResponse, resp, nil } diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index a31b97e2e..e9954d9a1 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -28,8 +28,8 @@ func TestDecisionsList(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") if r.URL.RawQuery == "ip=1.2.3.4" { - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) w.WriteHeader(http.StatusOK) w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`)) } else { @@ -83,6 +83,7 @@ func TestDecisionsList(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + if !reflect.DeepEqual(*decisions, *expected) { t.Fatalf("returned %+v, want %+v", resp, expected) } @@ -96,8 +97,8 @@ func TestDecisionsList(t *testing.T) { if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } - assert.Equal(t, len(*decisions), 0) + assert.Empty(t, *decisions) } func TestDecisionsStream(t *testing.T) { @@ -107,8 +108,7 @@ func TestDecisionsStream(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { @@ -121,7 +121,7 @@ func TestDecisionsStream(t *testing.T) { } }) mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) if r.Method == http.MethodDelete { w.WriteHeader(http.StatusOK) @@ -173,6 +173,7 @@ func TestDecisionsStream(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + if !reflect.DeepEqual(*decisions, *expected) { t.Fatalf("returned %+v, want %+v", resp, expected) } @@ -184,8 +185,9 @@ func TestDecisionsStream(t *testing.T) { if resp.Response.StatusCode != http.StatusOK { t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) } - assert.Equal(t, 0, len(decisions.New)) - assert.Equal(t, 0, len(decisions.Deleted)) + + assert.Empty(t, decisions.New) + assert.Empty(t, decisions.Deleted) //delete stream resp, err = newcli.Decisions.StopStream(context.Background()) @@ -203,8 +205,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) if r.Method == http.MethodGet { if r.URL.RawQuery == "startup=true" { @@ -275,6 +276,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + if !reflect.DeepEqual(*decisions, *expected) { t.Fatalf("returned %+v, want %+v", resp, expected) } @@ -287,8 +289,7 @@ func TestDecisionsStreamV3(t *testing.T) { defer teardown() mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") + assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) @@ -368,6 +369,7 @@ func TestDecisionsStreamV3(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + if !reflect.DeepEqual(*decisions, *expected) { t.Fatalf("returned %+v, want %+v", resp, expected) } @@ -451,6 +453,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { if err != nil { t.Fatalf("new api client: %s", err) } + if !reflect.DeepEqual(decisions, expected) { t.Fatalf("returned %+v, want %+v", decisions, expected) } @@ -484,7 +487,7 @@ func TestDeleteDecisions(t *testing.T) { }) mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") - assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") + assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"nbDeleted":"1"}`)) //w.Write([]byte(`{"message":"0 deleted alerts"}`)) @@ -512,6 +515,7 @@ func TestDeleteDecisions(t *testing.T) { if err != nil { t.Fatalf("unexpected err : %s", err) } + assert.Equal(t, "1", deleted.NbDeleted) defer teardown() @@ -519,6 +523,7 @@ func TestDeleteDecisions(t *testing.T) { func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { baseURLString := "http://localhost:8080/v1/decisions/stream" + type fields struct { Startup bool Scopes string @@ -553,6 +558,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, } + for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/apiclient/decisions_sync_service.go b/pkg/apiclient/decisions_sync_service.go index 57999691f..1aee9b6ca 100644 --- a/pkg/apiclient/decisions_sync_service.go +++ b/pkg/apiclient/decisions_sync_service.go @@ -15,7 +15,9 @@ type DecisionDeleteService service // DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) { var response interface{} + u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix) + req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions) if err != nil { return nil, nil, fmt.Errorf("while building request: %w", err) @@ -25,10 +27,12 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } + if resp.Response.StatusCode != http.StatusOK { log.Warnf("Decisions delete response : http %s", resp.Response.Status) } else { log.Debugf("Decisions delete response : http %s", resp.Response.Status) } + return &response, resp, nil } diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index fb8ef075e..bf61b8d2e 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -15,7 +15,6 @@ import ( type HeartBeatService service func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) { - u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix) req, err := h.client.NewRequest(http.MethodGet, u, nil) diff --git a/pkg/apiclient/metrics.go b/pkg/apiclient/metrics.go index ea447280a..a82273007 100644 --- a/pkg/apiclient/metrics.go +++ b/pkg/apiclient/metrics.go @@ -14,6 +14,7 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte var response interface{} u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &metrics) if err != nil { return nil, nil, err @@ -23,5 +24,6 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte if err != nil { return nil, resp, err } + return &response, resp, nil } diff --git a/pkg/apiclient/signal.go b/pkg/apiclient/signal.go index 2dceb8157..94c02f080 100644 --- a/pkg/apiclient/signal.go +++ b/pkg/apiclient/signal.go @@ -16,6 +16,7 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque var response interface{} u := fmt.Sprintf("%s/signals", s.client.URLPrefix) + req, err := s.client.NewRequest(http.MethodPost, u, &signals) if err != nil { return nil, nil, fmt.Errorf("while building request: %w", err) @@ -25,10 +26,12 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } + if resp.Response.StatusCode != http.StatusOK { log.Warnf("Signal push response : http %s", resp.Response.Status) } else { log.Debugf("Signal push response : http %s", resp.Response.Status) } + return &response, resp, nil } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 5fd23d116..5824eb060 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -9,13 +9,13 @@ import ( "sync" "testing" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" ) type LAPI struct { @@ -57,6 +57,7 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut if err != nil { l.t.Fatal(err) } + if authType == "apikey" { req.Header.Add("X-Api-Key", l.bouncerKey) } else if authType == "password" { @@ -64,7 +65,9 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut } else { l.t.Fatal("auth type not supported") } + l.router.ServeHTTP(w, req) + return w } @@ -78,6 +81,7 @@ func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csc if err != nil { return nil, models.WatcherAuthResponse{}, config, err } + return router, loginResp, config, nil } @@ -150,7 +154,6 @@ func TestCreateAlert(t *testing.T) { } func TestCreateAlertChannels(t *testing.T) { - apiServer, config, err := NewAPIServer(t) if err != nil { log.Fatalln(err) @@ -164,18 +167,22 @@ func TestCreateAlertChannels(t *testing.T) { } lapi := LAPI{router: apiServer.router, loginResp: loginResp} - var pd csplugin.ProfileAlert - var wg sync.WaitGroup + var ( + pd csplugin.ProfileAlert + wg sync.WaitGroup + ) wg.Add(1) + go func() { pd = <-apiServer.controller.PluginChannel + wg.Done() }() go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") wg.Wait() - assert.Equal(t, len(pd.Alert.Decisions), 1) + assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() } @@ -345,7 +352,6 @@ func TestAlertListFilters(t *testing.T) { w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) - } func TestAlertBulkInsert(t *testing.T) { @@ -393,7 +399,6 @@ func TestCreateAlertErrors(t *testing.T) { req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - } func TestDeleteAlert(t *testing.T) { @@ -506,5 +511,4 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") - } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index a77ab3f83..df61e0b26 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -48,5 +48,4 @@ func TestAPIKey(t *testing.T) { assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 7e4347c2a..dcf12929a 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -75,12 +75,14 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { if ret <= 0 { return 1 } + return ret } func (a *apic) FetchScenariosListFromDB() ([]string, error) { scenarios := make([]string, 0) machines, err := a.dbClient.ListMachines() + if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } @@ -88,18 +90,22 @@ func (a *apic) FetchScenariosListFromDB() ([]string, error) { for _, v := range machines { machineScenarios := strings.Split(v.Scenarios, ",") log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID) + for _, sv := range machineScenarios { if !slices.Contains(scenarios, sv) && sv != "" { scenarios = append(scenarios, sv) } } } + log.Debugf("Returning list of scenarios : %+v", scenarios) + return scenarios, nil } func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions { apiDecisions := models.AddSignalsRequestItemDecisions{} + for _, decision := range decisions { x := &models.AddSignalsRequestItemDecisionsItem{ Duration: ptr.Of(*decision.Duration), @@ -114,11 +120,14 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ UUID: decision.UUID, } *x.ID = decision.ID + if decision.Simulated != nil { x.Simulated = *decision.Simulated } + apiDecisions = append(apiDecisions, x) } + return apiDecisions } @@ -149,6 +158,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) } if shareContext { signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0) + for _, meta := range alert.Meta { contextItem := models.AddSignalsRequestItemContextItems0{ Key: meta.Key, @@ -157,13 +167,14 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) signal.Context = append(signal.Context, &contextItem) } } + return signal } func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error - ret := &apic{ + ret := &apic{ AlertsAddChan: make(chan []*models.Alert), dbClient: dbClient, mu: sync.Mutex{}, @@ -186,9 +197,11 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con password := strfmt.Password(config.Credentials.Password) apiURL, err := url.Parse(config.Credentials.URL) + if err != nil { return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err) } + papiURL, err := url.Parse(config.Credentials.PapiURL) if err != nil { return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) @@ -198,6 +211,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con if err != nil { return nil, fmt.Errorf("while fetching scenarios from db: %w", err) } + ret.apiClient, err = apiclient.NewClient(&apiclient.Config{ MachineID: config.Credentials.Login, Password: password, @@ -228,7 +242,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err) } - if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + if err = ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { return ret, fmt.Errorf("unable to parse jwt expiration: %w", err) } @@ -242,6 +256,7 @@ func (a *apic) Push() error { defer trace.CatchPanic("lapi/pushToAPIC") var cache models.AddSignalsRequest + ticker := time.NewTicker(a.pushIntervalFirst) log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval) @@ -252,28 +267,35 @@ func (a *apic) Push() error { a.pullTomb.Kill(nil) a.metricsTomb.Kill(nil) log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) + if len(cache) == 0 { return nil } + go a.Send(&cache) + return nil case <-ticker.C: ticker.Reset(a.pushInterval) + if len(cache) > 0 { a.mu.Lock() cacheCopy := cache cache = make(models.AddSignalsRequest, 0) a.mu.Unlock() log.Infof("Signal push: %d signals to push", len(cacheCopy)) + go a.Send(&cacheCopy) } case alerts := <-a.AlertsAddChan: var signals []*models.AddSignalsRequestItem + for _, alert := range alerts { if ok := shouldShareAlert(alert, a.consoleConfig); ok { signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext)) } } + a.mu.Lock() cache = append(cache, signals...) a.mu.Unlock() @@ -288,11 +310,13 @@ func getScenarioTrustOfAlert(alert *models.Alert) string { } else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" { scenarioTrust = "tainted" } + if len(alert.Decisions) > 0 { if *alert.Decisions[0].Origin == types.CscliOrigin { scenarioTrust = "manual" } } + return scenarioTrust } @@ -301,6 +325,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID) return false } + switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust { case "manual": if !*consoleConfig.ShareManualDecisions { @@ -318,6 +343,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig return false } } + return true } @@ -333,34 +359,44 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { I don't know enough about gin to tell how much of an issue it can be. */ - var cache []*models.AddSignalsRequestItem = *cacheOrig - var send models.AddSignalsRequest + var ( + cache []*models.AddSignalsRequestItem = *cacheOrig + send models.AddSignalsRequest + ) bulkSize := 50 pageStart := 0 pageEnd := bulkSize for { - if pageEnd >= len(cache) { send = cache[pageStart:] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := a.apiClient.Signal.Add(ctx, &send) + if err != nil { log.Errorf("sending signal to central API: %s", err) return } + break } + send = cache[pageStart:pageEnd] ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := a.apiClient.Signal.Add(ctx, &send) + if err != nil { //we log it here as well, because the return value of func might be discarded log.Errorf("sending signal to central API: %s", err) } + pageStart += bulkSize pageEnd += bulkSize } @@ -372,18 +408,22 @@ func (a *apic) CAPIPullIsOld() (bool, error) { alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert count, err := alerts.Count(a.dbClient.CTX) + if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } + if count > 0 { log.Printf("last CAPI pull is newer than 1h30, skip.") return false, nil } + return true, nil } -func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) { +func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { nbDeleted := 0 + for _, decision := range deletedDecisions { filter := map[string][]string{ "value": {*decision.Value}, @@ -398,20 +438,25 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet if err != nil { return 0, fmt.Errorf("deleting decisions error: %w", err) } + dbCliDel, err := strconv.Atoi(dbCliRet) if err != nil { return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) } - updateCounterForDecision(delete_counters, decision.Origin, decision.Scenario, dbCliDel) + + updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel) nbDeleted += dbCliDel } + return nbDeleted, nil } -func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) { +func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int + for _, decisions := range deletedDecisions { scope := decisions.Scope + for _, decision := range decisions.Decisions { filter := map[string][]string{ "value": {decision}, @@ -425,26 +470,32 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi if err != nil { return 0, fmt.Errorf("deleting decisions error: %w", err) } + dbCliDel, err := strconv.Atoi(dbCliRet) if err != nil { return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) } - updateCounterForDecision(delete_counters, ptr.Of(types.CAPIOrigin), nil, dbCliDel) + + updateCounterForDecision(deleteCounters, ptr.Of(types.CAPIOrigin), nil, dbCliDel) nbDeleted += dbCliDel } } + return nbDeleted, nil } func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { newAlerts := make([]*models.Alert, 0) + for _, decision := range decisions { found := false + for _, sub := range newAlerts { if sub.Source.Scope == nil { log.Warningf("nil scope in %+v", sub) continue } + if *decision.Origin == types.CAPIOrigin { if *sub.Source.Scope == types.CAPIOrigin { found = true @@ -464,11 +515,13 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { log.Warningf("unknown origin %s : %+v", *decision.Origin, decision) } } + if !found { log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario) newAlerts = append(newAlerts, createAlertForDecision(decision)) } } + return newAlerts } @@ -489,6 +542,7 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { // XXX: this or nil? scenario = "" scope = "" + log.Warningf("unknown origin %s", *decision.Origin) } @@ -512,10 +566,10 @@ func createAlertForDecision(decision *models.Decision) *models.Alert { } // This function takes in list of parent alerts and decisions and then pairs them up. -func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, add_counters map[string]map[string]int) []*models.Alert { +func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert { for _, decision := range decisions { //count and create separate alerts for each list - updateCounterForDecision(add_counters, decision.Origin, decision.Scenario, 1) + updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1) /*CAPI might send lower case scopes, unify it.*/ switch strings.ToLower(*decision.Scope) { @@ -524,6 +578,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio case "range": *decision.Scope = types.Range } + found := false //add the individual decisions to the right list for idx, alert := range alerts { @@ -531,6 +586,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio if *alert.Source.Scope == types.CAPIOrigin { alerts[idx].Decisions = append(alerts[idx].Decisions, decision) found = true + break } } else if *decision.Origin == types.ListOrigin { @@ -543,10 +599,12 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio log.Warningf("unknown origin %s", *decision.Origin) } } + if !found { log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario) } } + return alerts } @@ -581,18 +639,20 @@ func (a *apic) PullTop(forcePull bool) error { if err != nil { return fmt.Errorf("get stream: %w", err) } + a.startup = false /*to count additions/deletions across lists*/ log.Debugf("Received %d new decisions", len(data.New)) log.Debugf("Received %d deleted decisions", len(data.Deleted)) + if data.Links != nil { log.Debugf("Received %d blocklists links", len(data.Links.Blocklists)) } - add_counters, delete_counters := makeAddAndDeleteCounters() + addCounters, deleteCounters := makeAddAndDeleteCounters() // process deleted decisions - if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, delete_counters); err != nil { + if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil { return err } else { log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) @@ -610,28 +670,30 @@ func (a *apic) PullTop(forcePull bool) error { alert := createAlertForDecision(decisions[0]) alertsFromCapi := []*models.Alert{alert} - alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters) + alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(alertsFromCapi, add_counters, delete_counters) + err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters) if err != nil { return fmt.Errorf("while saving alerts: %w", err) } // update blocklists - if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil { + if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } + return nil } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { - add_counters, _ := makeAddAndDeleteCounters() + addCounters, _ := makeAddAndDeleteCounters() if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, - }, add_counters, forcePull); err != nil { + }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) } + return nil } @@ -641,17 +703,20 @@ func (a *apic) whitelistedBy(decision *models.Decision) string { if decision.Value == nil { return "" } + ipval := net.ParseIP(*decision.Value) for _, cidr := range a.whitelists.Cidrs { if cidr.Contains(ipval) { return cidr.String() } } + for _, ip := range a.whitelists.Ips { if ip != nil && ip.Equal(ipval) { return ip.String() } } + return "" } @@ -661,12 +726,14 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis } //deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place outIdx := 0 + for _, decision := range decisions { whitelister := a.whitelistedBy(decision) if whitelister != "" { log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister) continue } + decisions[outIdx] = decision outIdx++ } @@ -674,17 +741,20 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis return decisions[:outIdx] } -func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) error { +func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { for _, alert := range alertsFromCapi { - setAlertScenario(alert, add_counters, delete_counters) + setAlertScenario(alert, addCounters, deleteCounters) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) + if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) { log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } + log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID) } @@ -697,71 +767,91 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) alertInstance, err := alertQuery.First(context.Background()) + if err != nil { if ent.IsNotFound(err) { log.Debugf("no alert found for %s, force refresh", *blocklist.Name) return true, nil } + return false, fmt.Errorf("while getting alert: %w", err) } + decisionQuery := a.dbClient.Ent.Decision.Query() decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) firstDecision, err := decisionQuery.First(context.Background()) + if err != nil { if ent.IsNotFound(err) { log.Debugf("no decision found for %s, force refresh", *blocklist.Name) return true, nil } + return false, fmt.Errorf("while getting decision: %w", err) } + if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) { log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name) return true, nil } + return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil } + if blocklist.Duration == nil { log.Warningf("blocklist has no duration") return nil } + if !forcePull { _forcePull, err := a.ShouldForcePullBlocklist(blocklist) if err != nil { return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) } + forcePull = _forcePull } + blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) - var lastPullTimestamp *string - var err error + + var ( + lastPullTimestamp *string + err error + ) + if !forcePull { lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } + if !hasChanged { if lastPullTimestamp == nil { log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) } else { log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) } + return nil } + err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } + if len(decisions) == 0 { log.Infof("blocklist %s has no decisions", *blocklist.Name) return nil @@ -770,19 +860,21 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap decisions = a.ApplyApicWhitelists(decisions) alert := createAlertForDecision(decisions[0]) alertsFromCapi := []*models.Alert{alert} - alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters) + alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(alertsFromCapi, add_counters, nil) + err = a.SaveAlerts(alertsFromCapi, addCounters, nil) if err != nil { return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) } + return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } + if links.Blocklists == nil { return nil } @@ -792,21 +884,23 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink if err != nil { return fmt.Errorf("while creating default client: %w", err) } + for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil { + if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } + return nil } -func setAlertScenario(alert *models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) { +func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) { if *alert.Source.Scope == types.CAPIOrigin { *alert.Source.Scope = types.CommunityBlocklistPullSourceScope - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.CAPIOrigin]["all"], delete_counters[types.CAPIOrigin]["all"])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"])) } else if *alert.Source.Scope == types.ListOrigin { *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) - alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.ListOrigin][*alert.Scenario], delete_counters[types.ListOrigin][*alert.Scenario])) + alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario])) } } @@ -814,20 +908,26 @@ func (a *apic) Pull() error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false + for { scenario, err := a.FetchScenariosListFromDB() if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) } + if len(scenario) > 0 { break } + if !toldOnce { log.Warning("scenario list is empty, will not pull yet") + toldOnce = true } + time.Sleep(1 * time.Second) } + if err := a.PullTop(false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -839,6 +939,7 @@ func (a *apic) Pull() error { select { case <-ticker.C: ticker.Reset(a.pullInterval) + if err := a.PullTop(false); err != nil { log.Errorf("capi pull top: %s", err) continue @@ -846,6 +947,7 @@ func (a *apic) Pull() error { case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? a.metricsTomb.Kill(nil) a.pushTomb.Kill(nil) + return nil } } @@ -858,15 +960,15 @@ func (a *apic) Shutdown() { } func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) { - add_counters := make(map[string]map[string]int) - add_counters[types.CAPIOrigin] = make(map[string]int) - add_counters[types.ListOrigin] = make(map[string]int) + addCounters := make(map[string]map[string]int) + addCounters[types.CAPIOrigin] = make(map[string]int) + addCounters[types.ListOrigin] = make(map[string]int) - delete_counters := make(map[string]map[string]int) - delete_counters[types.CAPIOrigin] = make(map[string]int) - delete_counters[types.ListOrigin] = make(map[string]int) + deleteCounters := make(map[string]map[string]int) + deleteCounters[types.CAPIOrigin] = make(map[string]int) + deleteCounters[types.ListOrigin] = make(map[string]int) - return add_counters, delete_counters + return addCounters, deleteCounters } func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 9059928fe..128ce5a96 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -2,10 +2,10 @@ package apiserver import ( "context" - "slices" "time" log "github.com/sirupsen/logrus" + "slices" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" @@ -66,6 +66,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) { } // sorted slices are required for the slices.Equal comparison slices.Sort(ret) + return ret, nil } @@ -91,6 +92,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { if count < len(metInts)-1 { count++ } + return metInts[count] } @@ -100,8 +102,10 @@ func (a *apic) SendMetrics(stop chan (bool)) { ids, err := a.fetchMachineIDs() if err != nil { log.Debugf("unable to get machines (%s), will retry", err) + return } + machineIDs = ids } @@ -117,16 +121,20 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-stop: checkTicker.Stop() metTicker.Stop() + return case <-checkTicker.C: oldIDs := machineIDs + reloadMachineIDs() + if !slices.Equal(oldIDs, machineIDs) { log.Infof("capi metrics: machines changed, immediate send") metTicker.Reset(1 * time.Millisecond) } case <-metTicker.C: metTicker.Stop() + metrics, err := a.GetMetrics() if err != nil { log.Errorf("unable to get metrics (%s)", err) @@ -134,17 +142,20 @@ func (a *apic) SendMetrics(stop chan (bool)) { // metrics are nil if they could not be retrieved if metrics != nil { log.Info("capi metrics: sending") + _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) if err != nil { log.Errorf("capi metrics: failed: %s", err) } } + metTicker.Reset(nextMetInt()) case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? checkTicker.Stop() metTicker.Stop() a.pullTomb.Kill(nil) a.pushTomb.Kill(nil) + return } } diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index a12568e50..2bc0dd269 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -61,6 +61,7 @@ func TestAPICSendMetrics(t *testing.T) { httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{})) httpmock.Activate() + defer httpmock.Deactivate() for _, tc := range tests { diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 736a690c9..16dba1e86 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -44,12 +44,14 @@ func getDBClient(t *testing.T) *database.Client { DbPath: dbPath.Name(), }) require.NoError(t, err) + return dbClient } func getAPIC(t *testing.T) *apic { t.Helper() dbClient := getDBClient(t) + return &apic{ AlertsAddChan: make(chan []*models.Alert), //DecisionDeleteChan: make(chan []*models.Decision), @@ -74,6 +76,7 @@ func absDiff(a int, b int) (c int) { if c = a - b; c < 0 { return -1 * c } + return c } @@ -94,6 +97,7 @@ func jsonMarshalX(v interface{}) []byte { if err != nil { panic(err) } + return data } @@ -176,7 +180,6 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { assert.ElementsMatch(t, tc.expectedScenarios, scenarios) }) - } } @@ -220,6 +223,7 @@ func TestNewAPIC(t *testing.T) { expectedErr: "first path segment in URL cannot contain colon", }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -274,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { Scope: ptr.Of("IP"), }}, deleteCounters) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 2, nbDeleted) assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"]) } @@ -338,6 +342,7 @@ func TestAPICGetMetrics(t *testing.T) { }, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -394,6 +399,7 @@ func TestCreateAlertsForDecision(t *testing.T) { Origin: ptr.Of(types.CAPIOrigin), Scenario: ptr.Of("crowdsecurity/ssh-bf"), } + type args struct { decisions []*models.Decision } @@ -443,6 +449,7 @@ func TestCreateAlertsForDecision(t *testing.T) { }, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -477,6 +484,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { Scenario: ptr.Of("crowdsecurity/ssh-bf"), Scope: ptr.Of("ip"), } + type args struct { alerts []*models.Alert decisions []*models.Decision @@ -520,6 +528,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { }, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -546,12 +555,14 @@ func TestAPICWhitelists(t *testing.T) { if err != nil { t.Fatalf("unable to parse cidr : %s", err) } + api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) cidrwl1 = "11.2.3.0/24" _, tnet, err = net.ParseCIDR(cidrwl1) if err != nil { t.Fatalf("unable to parse cidr : %s", err) } + api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). @@ -564,6 +575,7 @@ func TestAPICWhitelists(t *testing.T) { assertTotalDecisionCount(t, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -681,33 +693,39 @@ func TestAPICWhitelists(t *testing.T) { AllX(context.Background()) decisionScenarioFreq := make(map[string]int) - decisionIp := make(map[string]int) + decisionIP := make(map[string]int) alertScenario := make(map[string]int) for _, alert := range alerts { alertScenario[alert.SourceScope]++ } - assert.Equal(t, 3, len(alertScenario)) + + assert.Len(t, alertScenario, 3) assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"]) for _, decisions := range validDecisions { decisionScenarioFreq[decisions.Scenario]++ - decisionIp[decisions.Value]++ + decisionIP[decisions.Value]++ } - assert.Equal(t, 1, decisionIp["2.2.3.4"], 1) - assert.Equal(t, 1, decisionIp["6.2.3.4"], 1) - if _, ok := decisionIp["13.2.3.4"]; ok { + + assert.Equal(t, 1, decisionIP["2.2.3.4"], 1) + assert.Equal(t, 1, decisionIP["6.2.3.4"], 1) + + if _, ok := decisionIP["13.2.3.4"]; ok { t.Errorf("13.2.3.4 is whitelisted") } - if _, ok := decisionIp["13.2.3.5"]; ok { + + if _, ok := decisionIP["13.2.3.5"]; ok { t.Errorf("13.2.3.5 is whitelisted") } - if _, ok := decisionIp["9.2.3.4"]; ok { + + if _, ok := decisionIP["9.2.3.4"]; ok { t.Errorf("9.2.3.4 is whitelisted") } + assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1) assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1) assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2) @@ -726,6 +744,7 @@ func TestAPICPullTop(t *testing.T) { assertTotalDecisionCount(t, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() + defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( @@ -817,7 +836,8 @@ func TestAPICPullTop(t *testing.T) { for _, alert := range alerts { alertScenario[alert.SourceScope]++ } - assert.Equal(t, 3, len(alertScenario)) + + assert.Len(t, alertScenario, 3) assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"]) @@ -835,6 +855,7 @@ func TestAPICPullTop(t *testing.T) { func TestAPICPullTopBLCacheFirstCall(t *testing.T) { // no decision in db, no last modified parameter. api := getAPIC(t) + httpmock.Activate() defer httpmock.DeactivateAndReset() httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( @@ -904,6 +925,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { func TestAPICPullTopBLCacheForceCall(t *testing.T) { api := getAPIC(t) + httpmock.Activate() defer httpmock.DeactivateAndReset() // create a decision about to expire. It should force fetch @@ -975,6 +997,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { func TestAPICPullBlocklistCall(t *testing.T) { api := getAPIC(t) + httpmock.Activate() defer httpmock.DeactivateAndReset() diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 9c455a38a..11d0c3eaa 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -2,8 +2,6 @@ package apiserver import ( "context" - "crypto/tls" - "crypto/x509" "fmt" "io" "net" @@ -15,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron" - "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" @@ -23,7 +20,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -32,9 +28,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - keyLength = 32 -) +const keyLength = 32 type APIServer struct { URL string @@ -52,57 +46,117 @@ type APIServer struct { isEnrolled bool } -// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. +func recoverFromPanic(c *gin.Context) { + err := recover() + if err == nil { + return + } + + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + brokenPipe := false + + if ne, ok := err.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { + brokenPipe = true + } + } + } + + // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go + // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them + if strErr, ok := err.(error); ok { + //stolen from http2/server.go in x/net + var ( + errClientDisconnected = errors.New("client disconnected") + errClosedBody = errors.New("body closed by handler") + errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + errStreamClosed = errors.New("http2: stream closed") + ) + + if errors.Is(strErr, errClientDisconnected) || + errors.Is(strErr, errClosedBody) || + errors.Is(strErr, errHandlerComplete) || + errors.Is(strErr, errStreamClosed) { + brokenPipe = true + } + } + + if brokenPipe { + log.Warningf("client %s disconnected : %s", c.ClientIP(), err) + c.Abort() + } else { + filename := trace.WriteStackTrace(err) + log.Warningf("client %s error : %s", c.ClientIP(), err) + log.Warningf("stacktrace written to %s, please join to your issue", filename) + c.AbortWithStatus(http.StatusInternalServerError) + } +} + +// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one. func CustomRecoveryWithWriter() gin.HandlerFunc { return func(c *gin.Context) { - defer func() { - if err := recover(); err != nil { - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace. - var brokenPipe bool - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true - } - } - } - - // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go - // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them - if strErr, ok := err.(error); ok { - //stolen from http2/server.go in x/net - var ( - errClientDisconnected = errors.New("client disconnected") - errClosedBody = errors.New("body closed by handler") - errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - errStreamClosed = errors.New("http2: stream closed") - ) - if errors.Is(strErr, errClientDisconnected) || - errors.Is(strErr, errClosedBody) || - errors.Is(strErr, errHandlerComplete) || - errors.Is(strErr, errStreamClosed) { - brokenPipe = true - } - } - - if brokenPipe { - log.Warningf("client %s disconnected : %s", c.ClientIP(), err) - c.Abort() - } else { - filename := trace.WriteStackTrace(err) - log.Warningf("client %s error : %s", c.ClientIP(), err) - log.Warningf("stacktrace written to %s, please join to your issue", filename) - c.AbortWithStatus(http.StatusInternalServerError) - } - } - }() + defer recoverFromPanic(c) c.Next() } } +// XXX: could be a method of LocalApiServerCfg +func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, error) { + clog := log.New() + + if err := types.ConfigureLogger(clog); err != nil { + return nil, "", fmt.Errorf("while configuring gin logger: %w", err) + } + + if config.LogLevel != nil { + clog.SetLevel(*config.LogLevel) + } + + if config.LogMedia != "file" { + return clog, "", nil + } + + // Log rotation + + logFile := filepath.Join(config.LogDir, "crowdsec_api.log") + log.Debugf("starting router, logging to %s", logFile) + + logger := &lumberjack.Logger{ + Filename: logFile, + MaxSize: 500, //megabytes + MaxBackups: 3, + MaxAge: 28, //days + Compress: true, //disabled by default + } + + if config.LogMaxSize != 0 { + logger.MaxSize = config.LogMaxSize + } + + if config.LogMaxFiles != 0 { + logger.MaxBackups = config.LogMaxFiles + } + + if config.LogMaxAge != 0 { + logger.MaxAge = config.LogMaxAge + } + + if config.CompressLogs != nil { + logger.Compress = *config.CompressLogs + } + + clog.SetOutput(logger) + + return clog, logFile, nil +} + +// NewServer creates a LAPI server. +// It sets up a gin router, a database client, and a controller. func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler + dbClient, err := database.NewClient(config.DbConfig) if err != nil { return nil, fmt.Errorf("unable to init database client: %w", err) @@ -115,63 +169,26 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { } } - logFile := "" - if config.LogMedia == "file" { - logFile = filepath.Join(config.LogDir, "crowdsec_api.log") - } - if log.GetLevel() < log.DebugLevel { gin.SetMode(gin.ReleaseMode) } - log.Debugf("starting router, logging to %s", logFile) + router := gin.New() + router.ForwardedByClientIP = false + if config.TrustedProxies != nil && config.UseForwardedForHeaders { - if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil { + if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil { return nil, fmt.Errorf("while setting trusted_proxies: %w", err) } + router.ForwardedByClientIP = true - } else { - router.ForwardedByClientIP = false } - /*The logger that will be used by handlers*/ - clog := log.New() - - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring gin logger: %w", err) - } - if config.LogLevel != nil { - clog.SetLevel(*config.LogLevel) - } - - /*Configure logs*/ - if logFile != "" { - _maxsize := 500 - if config.LogMaxSize != 0 { - _maxsize = config.LogMaxSize - } - _maxfiles := 3 - if config.LogMaxFiles != 0 { - _maxfiles = config.LogMaxFiles - } - _maxage := 28 - if config.LogMaxAge != 0 { - _maxage = config.LogMaxAge - } - _compress := true - if config.CompressLogs != nil { - _compress = *config.CompressLogs - } - - LogOutput := &lumberjack.Logger{ - Filename: logFile, - MaxSize: _maxsize, //megabytes - MaxBackups: _maxfiles, - MaxAge: _maxage, //days - Compress: _compress, //disabled by default - } - clog.SetOutput(LogOutput) + // The logger that will be used by handlers + clog, logFile, err := newGinLogger(config) + if err != nil { + return nil, err } gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel) @@ -206,41 +223,50 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration, } - var apiClient *apic - var papiClient *Papi - var isMachineEnrolled = false + var ( + apiClient *apic + papiClient *Papi + isMachineEnrolled = false + ) + + controller.AlertsAddChan = nil + controller.DecisionDeleteChan = nil if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") + apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err } + log.Infof("CAPI manager configured successfully") - isMachineEnrolled = isEnrolled(apiClient.apiClient) + controller.AlertsAddChan = apiClient.AlertsAddChan - if isMachineEnrolled { + + if apiClient.apiClient.IsEnrolled() { + isMachineEnrolled = true + log.Infof("Machine is enrolled in the console, Loading PAPI Client") + papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) if err != nil { return nil, err } + controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel } else { log.Errorf("Machine is not enrolled in the console, can't synchronize with the console") } - } else { - apiClient = nil - controller.AlertsAddChan = nil - controller.DecisionDeleteChan = nil } - if trustedIPs, err := config.GetTrustedIPs(); err == nil { - controller.TrustedIPs = trustedIPs - } else { + trustedIPs, err := config.GetTrustedIPs() + if err != nil { return nil, err } + controller.TrustedIPs = trustedIPs + return &APIServer{ URL: config.ListenURI, TLS: config.TLS, @@ -255,80 +281,20 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { consoleConfig: config.ConsoleConfig, isEnrolled: isMachineEnrolled, }, nil - -} - -func isEnrolled(client *apiclient.ApiClient) bool { - apiHTTPClient := client.GetClient() - jwtTransport := apiHTTPClient.Transport.(*apiclient.JWTTransport) - tokenStr := jwtTransport.Token - - token, _ := jwt.Parse(tokenStr, nil) - if token == nil { - return false - } - claims := token.Claims.(jwt.MapClaims) - _, ok := claims["organization_id"] - - return ok } func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } -func (s *APIServer) GetTLSConfig() (*tls.Config, error) { - var caCert []byte - var err error - var caCertPool *x509.CertPool - var clientAuthType tls.ClientAuthType - - if s.TLS == nil { - return &tls.Config{}, nil - } - - if s.TLS.ClientVerification == "" { - //sounds like a sane default : verify client cert if given, but don't make it mandatory - clientAuthType = tls.VerifyClientCertIfGiven - } else { - clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification) - if err != nil { - return nil, err - } - } - - if s.TLS.CACertPath != "" { - if clientAuthType > tls.RequestClientCert { - log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String()) - caCert, err = os.ReadFile(s.TLS.CACertPath) - if err != nil { - return nil, fmt.Errorf("while opening cert file: %w", err) - } - caCertPool, err = x509.SystemCertPool() - if err != nil { - log.Warnf("Error loading system CA certificates: %s", err) - } - if caCertPool == nil { - caCertPool = x509.NewCertPool() - } - caCertPool.AppendCertsFromPEM(caCert) - } - } - - return &tls.Config{ - ServerName: s.TLS.ServerName, //should it be removed ? - ClientAuth: clientAuthType, - ClientCAs: caCertPool, - MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details - }, nil -} - func (s *APIServer) Run(apiReady chan bool) error { defer trace.CatchPanic("lapi/runServer") - tlsCfg, err := s.GetTLSConfig() + + tlsCfg, err := s.TLS.GetTLSConfig() if err != nil { return fmt.Errorf("while creating TLS config: %w", err) } + s.httpServer = &http.Server{ Addr: s.URL, Handler: s.router, @@ -386,41 +352,74 @@ func (s *APIServer) Run(apiReady chan bool) error { }) } - s.httpServerTomb.Go(func() error { - go func() { - apiReady <- true - log.Infof("CrowdSec Local API listening on %s", s.URL) - if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { - if s.TLS.KeyFilePath == "" { - log.Fatalf("while serving local API: %v", errors.New("missing TLS key file")) - } else if s.TLS.CertFilePath == "" { - log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file")) - } - - if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil { - log.Fatalf("while serving local API: %v", err) - } - } else { - if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed { - log.Fatalf("while serving local API: %v", err) - } - } - }() - <-s.httpServerTomb.Dying() - return nil - }) + s.httpServerTomb.Go(func() error { s.listenAndServeURL(apiReady); return nil }) return nil } +// listenAndServeURL starts the http server and blocks until it's closed +// it also updates the URL field with the actual address the server is listening on +// it's meant to be run in a separate goroutine +func (s *APIServer) listenAndServeURL(apiReady chan bool) { + serverError := make(chan error, 1) + + go func() { + listener, err := net.Listen("tcp", s.URL) + if err != nil { + serverError <- fmt.Errorf("listening on %s: %w", s.URL, err) + return + } + + s.URL = listener.Addr().String() + log.Infof("CrowdSec Local API listening on %s", s.URL) + apiReady <- true + + if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") { + if s.TLS.KeyFilePath == "" { + serverError <- errors.New("missing TLS key file") + return + } else if s.TLS.CertFilePath == "" { + serverError <- errors.New("missing TLS cert file") + return + } + + err = s.httpServer.ServeTLS(listener, s.TLS.CertFilePath, s.TLS.KeyFilePath) + } else { + err = s.httpServer.Serve(listener) + } + + if err != nil && err != http.ErrServerClosed { + serverError <- fmt.Errorf("while serving local API: %w", err) + return + } + }() + + select { + case err := <-serverError: + log.Fatalf("while starting API server: %s", err) + case <-s.httpServerTomb.Dying(): + log.Infof("Shutting down API server") + // do we need a graceful shutdown here? + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.httpServer.Shutdown(ctx); err != nil { + log.Errorf("while shutting down http server: %s", err) + } + } +} + func (s *APIServer) Close() { if s.apic != nil { s.apic.Shutdown() // stop apic first since it use dbClient } + if s.papi != nil { s.papi.Shutdown() // papi also uses the dbClient } + s.dbClient.Ent.Close() + if s.flushScheduler != nil { s.flushScheduler.Stop() } @@ -428,6 +427,7 @@ func (s *APIServer) Close() { func (s *APIServer) Shutdown() error { s.Close() + if s.httpServer != nil { if err := s.httpServer.Shutdown(context.TODO()); err != nil { return err @@ -438,13 +438,17 @@ func (s *APIServer) Shutdown() error { if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { pipe.Close() } + if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok { pipe.Close() } + s.httpServerTomb.Kill(nil) + if err := s.httpServerTomb.Wait(); err != nil { return fmt.Errorf("while waiting on httpServerTomb: %w", err) } + return nil } @@ -453,36 +457,41 @@ func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) { } func (s *APIServer) InitController() error { - err := s.controller.Init() if err != nil { return fmt.Errorf("controller init: %w", err) } - if s.TLS != nil { - var cacheExpiration time.Duration - if s.TLS.CacheExpiration != nil { - cacheExpiration = *s.TLS.CacheExpiration - } else { - cacheExpiration = time.Hour - } - s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, - cacheExpiration, - log.WithFields(log.Fields{ - "component": "tls-auth", - "type": "agent", - })) - if err != nil { - return fmt.Errorf("while creating TLS auth for agents: %w", err) - } - s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, - cacheExpiration, - log.WithFields(log.Fields{ - "component": "tls-auth", - "type": "bouncer", - })) - if err != nil { - return fmt.Errorf("while creating TLS auth for bouncers: %w", err) - } + + if s.TLS == nil { + return nil } - return err + + // TLS is configured: create the TLSAuth middleware for agents and bouncers + + cacheExpiration := time.Hour + if s.TLS.CacheExpiration != nil { + cacheExpiration = *s.TLS.CacheExpiration + } + + s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "agent", + })) + if err != nil { + return fmt.Errorf("while creating TLS auth for agents: %w", err) + } + + s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "bouncer", + })) + if err != nil { + return fmt.Errorf("while creating TLS auth for bouncers: %w", err) + } + + return nil } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 6150c351b..62a8b83dd 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -11,21 +11,20 @@ import ( "testing" "time" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/go-openapi/strfmt" - "github.com/pkg/errors" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/gin-gonic/gin" - - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) var testMachineID = "test" @@ -46,6 +45,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -70,6 +70,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { if err := config.API.Server.LoadProfiles(); err != nil { log.Fatalf("failed to load profiles: %s", err) } + return config } @@ -81,6 +82,7 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -107,18 +109,22 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { if err := config.API.Server.LoadProfiles(); err != nil { log.Fatalf("failed to load profiles: %s", err) } + return config } func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) { config := LoadTestConfig(t) + os.Remove("./ent") apiServer, err := NewServer(config.API.Server) if err != nil { return nil, config, fmt.Errorf("unable to run local API: %s", err) } + log.Printf("Creating new API server") gin.SetMode(gin.TestMode) + return apiServer, config, nil } @@ -135,6 +141,7 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) { if err != nil { return nil, config, fmt.Errorf("unable to run local API: %s", err) } + return router, config, nil } @@ -150,12 +157,14 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) if err != nil { return nil, config, fmt.Errorf("unable to run local API: %s", err) } + log.Printf("Creating new API server") gin.SetMode(gin.TestMode) router, err := apiServer.Router() if err != nil { return nil, config, fmt.Errorf("unable to run local API: %s", err) } + return router, config, nil } @@ -164,9 +173,11 @@ func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error { if err != nil { return fmt.Errorf("unable to create new database client: %s", err) } + if err := dbClient.ValidateMachine(machineID); err != nil { return fmt.Errorf("unable to validate machine: %s", err) } + return nil } @@ -179,23 +190,24 @@ func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error if err != nil { return "", fmt.Errorf("Unable to list machines: %s", err) } + for _, machine := range machines { if machine.MachineId == machineID { return machine.IpAddress, nil } } + return "", nil } func GetAlertReaderFromFile(path string) *strings.Reader { - alertContentBytes, err := os.ReadFile(path) if err != nil { log.Fatal(err) } alerts := make([]*models.Alert, 0) - if err := json.Unmarshal(alertContentBytes, &alerts); err != nil { + if err = json.Unmarshal(alertContentBytes, &alerts); err != nil { log.Fatal(err) } @@ -208,12 +220,13 @@ func GetAlertReaderFromFile(path string) *strings.Reader { if err != nil { log.Fatal(err) } - return strings.NewReader(string(alertContent)) + return strings.NewReader(string(alertContent)) } func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) { var response []*models.Decision + if resp == nil { return nil, 0, errors.New("response is nil") } @@ -221,11 +234,13 @@ func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, if err != nil { return nil, resp.Code, err } + return response, resp.Code, nil } func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) { var response map[string]string + if resp == nil { return nil, 0, errors.New("response is nil") } @@ -233,11 +248,13 @@ func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, if err != nil { return nil, resp.Code, err } + return response, resp.Code, nil } func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) { var response models.DeleteDecisionResponse + if resp == nil { return nil, 0, errors.New("response is nil") } @@ -245,11 +262,13 @@ func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDec if err != nil { return nil, resp.Code, err } + return &response, resp.Code, nil } func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) { response := make(map[string][]*models.Decision) + if resp == nil { return nil, 0, errors.New("response is nil") } @@ -257,6 +276,7 @@ func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*mod if err != nil { return nil, resp.Code, err } + return response, resp.Code, nil } @@ -271,6 +291,7 @@ func CreateTestMachine(router *gin.Engine) (string, error) { req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) + return body, nil } @@ -279,10 +300,12 @@ func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) { if err != nil { log.Fatalf("unable to create new database client: %s", err) } + apiKey, err := middlewares.GenerateAPIKey(keyLength) if err != nil { return "", fmt.Errorf("unable to generate api key: %s", err) } + _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) if err != nil { return "", fmt.Errorf("unable to create blocker: %s", err) @@ -322,7 +345,6 @@ func TestUnknownPath(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 404, w.Code) - } /* @@ -348,6 +370,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -370,10 +393,12 @@ func TestLoggingDebugToFileConfig(t *testing.T) { if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { t.Fatal(err) } + api, err := NewServer(&cfg) if err != nil { t.Fatalf("failed to create api : %s", err) } + if api == nil { t.Fatalf("failed to create api #2 is nbill") } @@ -397,11 +422,9 @@ func TestLoggingDebugToFileConfig(t *testing.T) { t.Fatalf("expected %s in %s", expectedStr, string(data)) } } - } func TestLoggingErrorToFileConfig(t *testing.T) { - /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -409,6 +432,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) { } tempDir, _ := os.MkdirTemp("", "crowdsec_tests") + t.Cleanup(func() { os.RemoveAll(tempDir) }) dbconfig := csconfig.DatabaseCfg{ @@ -434,6 +458,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) { if err != nil { t.Fatalf("failed to create api : %s", err) } + if api == nil { t.Fatalf("failed to create api #2 is nbill") } diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index e0a1656e7..5794b40d3 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -6,13 +6,14 @@ import ( "net/http" "github.com/alexliesenfeld/health" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) type Controller struct { diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 66d19288d..10841ce45 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -10,15 +10,15 @@ import ( "time" jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" "github.com/google/uuid" + log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" ) func FormatOneAlert(alert *ent.Alert) *models.Alert { diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 8ea379873..534870484 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -7,11 +7,12 @@ import ( "strconv" "time" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) // Format decisions for the bouncers diff --git a/pkg/apiserver/controllers/v1/errors.go b/pkg/apiserver/controllers/v1/errors.go index 5edf0d6bf..b85b811f8 100644 --- a/pkg/apiserver/controllers/v1/errors.go +++ b/pkg/apiserver/controllers/v1/errors.go @@ -3,9 +3,10 @@ package v1 import ( "net/http" - "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/gin-gonic/gin" "github.com/pkg/errors" + + "github.com/crowdsecurity/crowdsec/pkg/database" ) func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index b4f28d94f..55f79d0c9 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -3,10 +3,11 @@ package v1 import ( "net/http" - "github.com/crowdsecurity/crowdsec/pkg/models" - "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" ) func (c *Controller) CreateMachine(gctx *gin.Context) { diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 0f3bdb6d1..676cc31ea 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -35,8 +35,11 @@ var LapiBouncerHits = prometheus.NewCounterVec( []string{"bouncer", "route", "method"}, ) -/* keep track of the number of calls (per bouncer) that lead to nil/non-nil responses. -while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi*/ +/* + keep track of the number of calls (per bouncer) that lead to nil/non-nil responses. + +while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi +*/ var LapiNilDecisions = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_lapi_decisions_ko_total", diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 8edce5898..aaa17ca51 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -4,8 +4,9 @@ import ( "fmt" "net/http" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/gin-gonic/gin" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) var ( diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 5f92b1f08..465accbac 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -91,9 +92,9 @@ func TestGetDecisionFilters(t *testing.T) { w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err := readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -106,9 +107,9 @@ func TestGetDecisionFilters(t *testing.T) { w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -124,9 +125,9 @@ func TestGetDecisionFilters(t *testing.T) { w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions)) + assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -139,9 +140,9 @@ func TestGetDecisionFilters(t *testing.T) { w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions)) + assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, int64(1), decisions[0].ID) @@ -153,12 +154,11 @@ func TestGetDecisionFilters(t *testing.T) { w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 2, len(decisions)) + assert.Len(t, decisions, 2) assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179") assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178") - } func TestGetDecision(t *testing.T) { @@ -171,9 +171,9 @@ func TestGetDecision(t *testing.T) { w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code, err := readDecisionsGetResp(w) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 3, len(decisions)) + assert.Len(t, decisions, 3) /*decisions get doesn't perform deduplication*/ assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario) assert.Equal(t, "127.0.0.1", *decisions[0].Value) @@ -190,7 +190,7 @@ func TestGetDecision(t *testing.T) { // Get Decision with invalid filter. It should ignore this filter w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - assert.Equal(t, 3, len(decisions)) + assert.Len(t, decisions, 3) } func TestDeleteDecisionByID(t *testing.T) { @@ -202,47 +202,47 @@ func TestDeleteDecisionByID(t *testing.T) { //Have one alerts w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err := readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) - err_resp, _, err := readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, "decision_id must be valid integer", err_resp["message"]) + errResp, _, err := readDecisionsErrorResp(w) + require.NoError(t, err) + assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - err_resp, _, err = readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", err_resp["message"]) + errResp, _, err = readDecisionsErrorResp(w) + require.NoError(t, err) + assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) //Have one alerts w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) // Delete alert with valid ID w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _, err := readDecisionsDeleteResp(w) - assert.NoError(t, err) - assert.Equal(t, resp.NbDeleted, "1") + require.NoError(t, err) + assert.Equal(t, "1", resp.NbDeleted) //Have one alert (because we delete an alert that has dup targets) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, err, nil) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) } func TestDeleteDecision(t *testing.T) { @@ -254,16 +254,16 @@ func TestDeleteDecision(t *testing.T) { // Delete alert with Invalid filter w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - err_resp, _, err := readDecisionsErrorResp(w) - assert.NoError(t, err) - assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter") + errResp, _, err := readDecisionsErrorResp(w) + require.NoError(t, err) + assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _, err := readDecisionsDeleteResp(w) - assert.NoError(t, err) - assert.Equal(t, resp.NbDeleted, "3") + require.NoError(t, err) + assert.Equal(t, "3", resp.NbDeleted) } func TestStreamStartDecisionDedup(t *testing.T) { @@ -276,10 +276,10 @@ func TestStreamStartDecisionDedup(t *testing.T) { // Get Stream, we only get one decision (the longest one) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err := readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(3), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) @@ -291,10 +291,10 @@ func TestStreamStartDecisionDedup(t *testing.T) { // Get Stream, we only get one decision (the longest one, id=2) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(2), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) @@ -306,10 +306,10 @@ func TestStreamStartDecisionDedup(t *testing.T) { // And get the remaining decision (1) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 0, len(decisions["deleted"])) - assert.Equal(t, 1, len(decisions["new"])) + assert.Empty(t, decisions["deleted"]) + assert.Len(t, decisions["new"], 1) assert.Equal(t, int64(1), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) @@ -321,13 +321,13 @@ func TestStreamStartDecisionDedup(t *testing.T) { //and now we only get a deleted decision w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code, err = readDecisionsStreamResp(w) - assert.Equal(t, nil, err) + require.NoError(t, err) assert.Equal(t, 200, code) - assert.Equal(t, 1, len(decisions["deleted"])) + assert.Len(t, decisions["deleted"], 1) assert.Equal(t, int64(1), decisions["deleted"][0].ID) assert.Equal(t, "test", *decisions["deleted"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value) - assert.Equal(t, 0, len(decisions["new"])) + assert.Empty(t, decisions["new"]) } type DecisionCheck struct { diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index e5c3529cc..886962250 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -91,5 +91,4 @@ func TestLogin(t *testing.T) { assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "\"token\"") assert.Contains(t, w.Body.String(), "\"expire\"") - } diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 25fd0eaf4..6ac016404 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -49,7 +49,6 @@ func TestCreateMachine(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - } func TestCreateMachineWithForwardedFor(t *testing.T) { @@ -78,6 +77,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { if err != nil { log.Fatalf("Could not get machine IP : %s", err) } + assert.Equal(t, "1.1.1.1", ip) } @@ -165,5 +165,4 @@ func TestCreateMachineAlreadyExist(t *testing.T) { assert.Equal(t, 403, w.Code) assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String()) - } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 1481a0145..7e4df875c 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -8,18 +8,19 @@ import ( "net/http" "strings" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" ) const ( APIKeyHeader = "X-Api-Key" bouncerContextKey = "bouncer_info" // max allowed by bcrypt 72 = 54 bytes in base64 - dummyAPIKeySize = 54 + dummyAPIKeySize = 54 ) type APIKey struct { diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 22c171c63..8797761a4 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -10,15 +10,16 @@ import ( "time" jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" - "github.com/gin-gonic/gin" - "github.com/go-openapi/strfmt" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" ) var identityKey = "id" @@ -46,16 +47,12 @@ func IdentityHandler(c *gin.Context) interface{} { } } - - type authInput struct { - machineID string - clientMachine *ent.Machine + machineID string + clientMachine *ent.Machine scenariosInput []string } - - func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { ret := authInput{} @@ -123,8 +120,6 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { return &ret, nil } - - func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { var loginInput models.WatcherAuthRequest var err error @@ -169,7 +164,6 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { return &ret, nil } - func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { var err error var auth *authInput diff --git a/pkg/apiserver/utils.go b/pkg/apiserver/utils.go deleted file mode 100644 index 409d79b01..000000000 --- a/pkg/apiserver/utils.go +++ /dev/null @@ -1,27 +0,0 @@ -package apiserver - -import ( - "crypto/tls" - "fmt" - - log "github.com/sirupsen/logrus" -) - -func getTLSAuthType(authType string) (tls.ClientAuthType, error) { - switch authType { - case "NoClientCert": - return tls.NoClientCert, nil - case "RequestClientCert": - log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") - return tls.RequestClientCert, nil - case "RequireAnyClientCert": - log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") - return tls.RequireAnyClientCert, nil - case "VerifyClientCertIfGiven": - return tls.VerifyClientCertIfGiven, nil - case "RequireAndVerifyClientCert": - return tls.RequireAndVerifyClientCert, nil - default: - return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType) - } -} diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index c1577782f..07b8d154c 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -212,18 +212,6 @@ type LocalApiServerCfg struct { CapiWhitelists *CapiWhitelist `yaml:"-"` } -type TLSCfg struct { - CertFilePath string `yaml:"cert_file"` - KeyFilePath string `yaml:"key_file"` - ClientVerification string `yaml:"client_verification,omitempty"` - ServerName string `yaml:"server_name"` - CACertPath string `yaml:"ca_cert_path"` - AllowedAgentsOU []string `yaml:"agents_allowed_ou"` - AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"` - CRLPath string `yaml:"crl_path"` - CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"` -} - func (c *Config) LoadAPIServer() error { if c.DisableAPI { log.Warning("crowdsec local API is disabled from flag") @@ -243,13 +231,16 @@ func (c *Config) LoadAPIServer() error { if !*c.API.Server.Enable { log.Warning("crowdsec local API is disabled because 'enable' is set to false") c.DisableAPI = true - return nil } if c.DisableAPI { return nil } + if c.API.Server.ListenURI == "" { + return fmt.Errorf("no listen_uri specified") + } + //inherit log level from common, then api->server var logLevel log.Level if c.API.Server.LogLevel != nil { diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index 10128b76b..b39d6eccf 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -219,7 +219,9 @@ func TestLoadAPIServer(t *testing.T) { input: &Config{ Self: []byte(configData), API: &APICfg{ - Server: &LocalApiServerCfg{}, + Server: &LocalApiServerCfg{ + ListenURI: "http://crowdsec.api", + }, }, Common: &CommonCfg{ LogDir: "./testdata/", diff --git a/pkg/csconfig/tls.go b/pkg/csconfig/tls.go new file mode 100644 index 000000000..897112a75 --- /dev/null +++ b/pkg/csconfig/tls.go @@ -0,0 +1,87 @@ +package csconfig + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "time" + + log "github.com/sirupsen/logrus" +) + +type TLSCfg struct { + CertFilePath string `yaml:"cert_file"` + KeyFilePath string `yaml:"key_file"` + ClientVerification string `yaml:"client_verification,omitempty"` + ServerName string `yaml:"server_name"` + CACertPath string `yaml:"ca_cert_path"` + AllowedAgentsOU []string `yaml:"agents_allowed_ou"` + AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"` + CRLPath string `yaml:"crl_path"` + CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"` +} + +func (t *TLSCfg) GetAuthType() (tls.ClientAuthType, error) { + if t.ClientVerification == "" { + // sounds like a sane default: verify client cert if given, but don't make it mandatory + return tls.VerifyClientCertIfGiven, nil + } + + switch t.ClientVerification { + case "NoClientCert": + return tls.NoClientCert, nil + case "RequestClientCert": + log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") + return tls.RequestClientCert, nil + case "RequireAnyClientCert": + log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") + return tls.RequireAnyClientCert, nil + case "VerifyClientCertIfGiven": + return tls.VerifyClientCertIfGiven, nil + case "RequireAndVerifyClientCert": + return tls.RequireAndVerifyClientCert, nil + default: + return 0, fmt.Errorf("unknown TLS client_verification value: %s", t.ClientVerification) + } +} + +func (t *TLSCfg) GetTLSConfig() (*tls.Config, error) { + if t == nil { + return &tls.Config{}, nil + } + + clientAuthType, err := t.GetAuthType() + if err != nil { + return nil, err + } + + caCertPool, err := x509.SystemCertPool() + if err != nil { + log.Warnf("Error loading system CA certificates: %s", err) + } + + if caCertPool == nil { + caCertPool = x509.NewCertPool() + } + + // the > condition below is a weird way to say "if a client certificate is required" + // see https://pkg.go.dev/crypto/tls#ClientAuthType + if clientAuthType > tls.RequestClientCert && t.CACertPath != "" { + log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String()) + + caCert, err := os.ReadFile(t.CACertPath) + if err != nil { + return nil, fmt.Errorf("while opening cert file: %w", err) + } + + caCertPool.AppendCertsFromPEM(caCert) + } + + return &tls.Config{ + ServerName: t.ServerName, //should it be removed ? + ClientAuth: clientAuthType, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details + }, nil +} diff --git a/test/bats/01_crowdsec_lapi.bats b/test/bats/01_crowdsec_lapi.bats new file mode 100644 index 000000000..4819d724f --- /dev/null +++ b/test/bats/01_crowdsec_lapi.bats @@ -0,0 +1,51 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + load "../lib/bats-file/load.bash" + ./instance-data load +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +# Tests for LAPI configuration and startup + +@test "lapi (.api.server.enable=false)" { + rune -0 config_set '.api.server.enable=false' + rune -1 "${CROWDSEC}" -no-cs + assert_stderr --partial "You must run at least the API Server or crowdsec" +} + +@test "lapi (no .api.server.listen_uri)" { + rune -0 config_set 'del(.api.server.listen_uri)' + rune -1 "${CROWDSEC}" -no-cs + assert_stderr --partial "no listen_uri specified" +} + +@test "lapi (bad .api.server.listen_uri)" { + rune -0 config_set '.api.server.listen_uri="127.0.0.1:-80"' + rune -1 "${CROWDSEC}" -no-cs + assert_stderr --partial "while starting API server: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port" +} + +@test "lapi (listen on random port)" { + config_set '.common.log_media="stdout"' + rune -0 config_set '.api.server.listen_uri="127.0.0.1:0"' + rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs +} + diff --git a/test/bats/01_cscli.bats b/test/bats/01_cscli.bats index d99510f98..dd03ea207 100644 --- a/test/bats/01_cscli.bats +++ b/test/bats/01_cscli.bats @@ -15,7 +15,7 @@ setup() { load "../lib/setup.sh" load "../lib/bats-file/load.bash" ./instance-data load - ./instance-crowdsec start + # don't run crowdsec here, not all tests require a running instance } teardown() { @@ -204,6 +204,7 @@ teardown() { } @test "cscli lapi status" { + rune -0 ./instance-crowdsec start rune -0 cscli lapi status assert_stderr --partial "Loaded credentials from" @@ -260,6 +261,7 @@ teardown() { } @test "cscli - bad LAPI password" { + rune -0 ./instance-crowdsec start LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"' @@ -269,6 +271,7 @@ teardown() { } @test "cscli metrics" { + rune -0 ./instance-crowdsec start rune -0 cscli lapi status rune -0 cscli metrics assert_output --partial "Route" @@ -297,6 +300,7 @@ teardown() { } @test "cscli explain" { + rune -0 ./instance-crowdsec start line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" rune -0 cscli parsers install crowdsecurity/syslog-logs