light pkg/api{client,server} refact (#2659)
* tests: don't run crowdsec if not necessary * make listen_uri report the random port number when 0 is requested * move apiserver.getTLSAuthType() -> csconfig.TLSCfg.GetAuthType() * move apiserver.isEnrolled() -> apiclient.ApiClient.IsEnrolled() * extract function apiserver.recoverFromPanic() * simplify and move APIServer.GetTLSConfig() -> TLSCfg.GetTLSConfig() * moved TLSCfg type to csconfig/tls.go * APIServer.InitController(): early return / happy path * extract function apiserver.newGinLogger() * lapi tests * update unit test * lint (testify) * lint (whitespace, variable names) * update docker tests
This commit is contained in:
parent
67cdf91f94
commit
89f704ef18
46 changed files with 927 additions and 477 deletions
|
@ -13,7 +13,7 @@ def test_no_agent(crowdsec, flavor):
|
||||||
'DISABLE_AGENT': 'true',
|
'DISABLE_AGENT': 'true',
|
||||||
}
|
}
|
||||||
with crowdsec(flavor=flavor, environment=env) as cs:
|
with crowdsec(flavor=flavor, environment=env) as cs:
|
||||||
cs.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*")
|
cs.wait_for_log("*CrowdSec Local API listening on *:8080*")
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
assert res.exit_code == 0
|
assert res.exit_code == 0
|
||||||
|
@ -37,7 +37,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
|
||||||
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*Generate local agent credentials*",
|
"*Generate local agent credentials*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
|
@ -50,7 +50,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
|
||||||
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*Generate local agent credentials*",
|
"*Generate local agent credentials*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
|
@ -65,7 +65,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
|
||||||
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*Generate local agent credentials*",
|
"*Generate local agent credentials*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
|
@ -78,7 +78,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
|
||||||
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*Local agent already registered*",
|
"*Local agent already registered*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
|
|
|
@ -29,7 +29,7 @@ def test_split_lapi_agent(crowdsec, flavor):
|
||||||
cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor)
|
cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor)
|
||||||
|
|
||||||
with cs_lapi as lapi:
|
with cs_lapi as lapi:
|
||||||
lapi.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*")
|
lapi.wait_for_log("*CrowdSec Local API listening on *:8080*")
|
||||||
lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
with cs_agent as agent:
|
with cs_agent as agent:
|
||||||
agent.wait_for_log("*Starting processing data*")
|
agent.wait_for_log("*Starting processing data*")
|
||||||
|
|
|
@ -11,7 +11,7 @@ def test_local_api_url_default(crowdsec, flavor):
|
||||||
"""Test LOCAL_API_URL (default)"""
|
"""Test LOCAL_API_URL (default)"""
|
||||||
with crowdsec(flavor=flavor) as cs:
|
with crowdsec(flavor=flavor) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
"*Starting processing data*"
|
"*Starting processing data*"
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
|
@ -29,7 +29,7 @@ def test_local_api_url(crowdsec, flavor):
|
||||||
}
|
}
|
||||||
with crowdsec(flavor=flavor, environment=env) as cs:
|
with crowdsec(flavor=flavor, environment=env) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on *:8080*",
|
||||||
"*Starting processing data*"
|
"*Starting processing data*"
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
|
@ -54,7 +54,7 @@ def test_local_api_url_ipv6(crowdsec, flavor):
|
||||||
with crowdsec(flavor=flavor, environment=env) as cs:
|
with crowdsec(flavor=flavor, environment=env) as cs:
|
||||||
cs.wait_for_log([
|
cs.wait_for_log([
|
||||||
"*Starting processing data*",
|
"*Starting processing data*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*",
|
"*CrowdSec Local API listening on [::1]:8080*",
|
||||||
])
|
])
|
||||||
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
|
||||||
res = cs.cont.exec_run('cscli lapi status')
|
res = cs.cont.exec_run('cscli lapi status')
|
||||||
|
|
|
@ -23,7 +23,7 @@ def test_missing_key_file(crowdsec, flavor):
|
||||||
|
|
||||||
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
|
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
|
||||||
# XXX: this message appears twice, is that normal?
|
# XXX: this message appears twice, is that normal?
|
||||||
cs.wait_for_log("*while serving local API: missing TLS key file*")
|
cs.wait_for_log("*while starting API server: missing TLS key file*")
|
||||||
|
|
||||||
|
|
||||||
def test_missing_cert_file(crowdsec, flavor):
|
def test_missing_cert_file(crowdsec, flavor):
|
||||||
|
@ -35,7 +35,7 @@ def test_missing_cert_file(crowdsec, flavor):
|
||||||
}
|
}
|
||||||
|
|
||||||
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
|
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
|
||||||
cs.wait_for_log("*while serving local API: missing TLS cert file*")
|
cs.wait_for_log("*while starting API server: missing TLS cert file*")
|
||||||
|
|
||||||
|
|
||||||
def test_tls_missing_ca(crowdsec, flavor, certs_dir):
|
def test_tls_missing_ca(crowdsec, flavor, certs_dir):
|
||||||
|
@ -174,7 +174,7 @@ def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir):
|
||||||
with cs_lapi as lapi:
|
with cs_lapi as lapi:
|
||||||
lapi.wait_for_log([
|
lapi.wait_for_log([
|
||||||
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*"
|
"*CrowdSec Local API listening on *:8080*"
|
||||||
])
|
])
|
||||||
# TODO: wait_for_https
|
# TODO: wait_for_https
|
||||||
lapi.wait_for_http(8080, '/health', want_status=None)
|
lapi.wait_for_http(8080, '/health', want_status=None)
|
||||||
|
@ -225,7 +225,7 @@ def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir):
|
||||||
with cs_lapi as lapi:
|
with cs_lapi as lapi:
|
||||||
lapi.wait_for_log([
|
lapi.wait_for_log([
|
||||||
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*"
|
"*CrowdSec Local API listening on *:8080*"
|
||||||
])
|
])
|
||||||
# TODO: wait_for_https
|
# TODO: wait_for_https
|
||||||
lapi.wait_for_http(8080, '/health', want_status=None)
|
lapi.wait_for_http(8080, '/health', want_status=None)
|
||||||
|
@ -276,7 +276,7 @@ def test_tls_client_ou(crowdsec, certs_dir):
|
||||||
with cs_lapi as lapi:
|
with cs_lapi as lapi:
|
||||||
lapi.wait_for_log([
|
lapi.wait_for_log([
|
||||||
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*"
|
"*CrowdSec Local API listening on *:8080*"
|
||||||
])
|
])
|
||||||
# TODO: wait_for_https
|
# TODO: wait_for_https
|
||||||
lapi.wait_for_http(8080, '/health', want_status=None)
|
lapi.wait_for_http(8080, '/health', want_status=None)
|
||||||
|
@ -306,7 +306,7 @@ def test_tls_client_ou(crowdsec, certs_dir):
|
||||||
with cs_lapi as lapi:
|
with cs_lapi as lapi:
|
||||||
lapi.wait_for_log([
|
lapi.wait_for_log([
|
||||||
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
|
||||||
"*CrowdSec Local API listening on 0.0.0.0:8080*"
|
"*CrowdSec Local API listening on *:8080*"
|
||||||
])
|
])
|
||||||
# TODO: wait_for_https
|
# TODO: wait_for_https
|
||||||
lapi.wait_for_http(8080, '/health', want_status=None)
|
lapi.wait_for_http(8080, '/health', want_status=None)
|
||||||
|
|
|
@ -49,31 +49,37 @@ type AlertsDeleteOpts struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
|
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
|
||||||
|
var addedIds models.AddAlertsResponse
|
||||||
var added_ids models.AddAlertsResponse
|
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
|
||||||
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
|
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := s.client.Do(ctx, req, &added_ids)
|
resp, err := s.client.Do(ctx, req, &addedIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
return &added_ids, resp, nil
|
|
||||||
|
return &addedIds, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// to demo query arguments
|
// to demo query arguments
|
||||||
func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) {
|
func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) {
|
||||||
var alerts models.GetAlertsResponse
|
var (
|
||||||
var URI string
|
alerts models.GetAlertsResponse
|
||||||
|
URI string
|
||||||
|
)
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
|
||||||
params, err := qs.Values(opts)
|
params, err := qs.Values(opts)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("building query: %w", err)
|
return nil, nil, fmt.Errorf("building query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(params) > 0 {
|
if len(params) > 0 {
|
||||||
URI = fmt.Sprintf("%s?%s", u, params.Encode())
|
URI = fmt.Sprintf("%s?%s", u, params.Encode())
|
||||||
} else {
|
} else {
|
||||||
|
@ -89,16 +95,19 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, fmt.Errorf("performing request: %w", err)
|
return nil, resp, fmt.Errorf("performing request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &alerts, resp, nil
|
return &alerts, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// to demo query arguments
|
// to demo query arguments
|
||||||
func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) {
|
func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) {
|
||||||
var alerts models.DeleteAlertsResponse
|
var alerts models.DeleteAlertsResponse
|
||||||
|
|
||||||
params, err := qs.Values(opts)
|
params, err := qs.Values(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())
|
u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
|
@ -110,12 +119,14 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &alerts, resp, nil
|
return &alerts, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models.DeleteAlertsResponse, *Response, error) {
|
func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
|
||||||
var alerts models.DeleteAlertsResponse
|
var alerts models.DeleteAlertsResponse
|
||||||
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alert_id)
|
|
||||||
|
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -126,11 +137,13 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &alerts, resp, nil
|
return &alerts, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
|
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
|
||||||
var alert models.Alert
|
var alert models.Alert
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
|
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||||
|
@ -142,5 +155,6 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &alert, resp, nil
|
return &alert, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,10 +26,12 @@ func TestAlertsListAsMachine(t *testing.T) {
|
||||||
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
|
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
|
||||||
})
|
})
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("parsing api url: %s", apiURL)
|
log.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -199,6 +201,7 @@ func TestAlertsListAsMachine(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("test Unable to list alerts : %+v", err)
|
log.Errorf("test Unable to list alerts : %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
@ -209,14 +212,17 @@ func TestAlertsListAsMachine(t *testing.T) {
|
||||||
//this one doesn't
|
//this one doesn't
|
||||||
filter := AlertsListOpts{IPEquals: new(string)}
|
filter := AlertsListOpts{IPEquals: new(string)}
|
||||||
*filter.IPEquals = "1.2.3.4"
|
*filter.IPEquals = "1.2.3.4"
|
||||||
|
|
||||||
alerts, resp, err = client.Alerts.List(context.Background(), filter)
|
alerts, resp, err = client.Alerts.List(context.Background(), filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("test Unable to list alerts : %+v", err)
|
log.Errorf("test Unable to list alerts : %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
assert.Equal(t, 0, len(*alerts))
|
|
||||||
|
assert.Empty(t, *alerts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAlertsGetAsMachine(t *testing.T) {
|
func TestAlertsGetAsMachine(t *testing.T) {
|
||||||
|
@ -228,10 +234,12 @@ func TestAlertsGetAsMachine(t *testing.T) {
|
||||||
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
|
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
|
||||||
})
|
})
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("parsing api url: %s", apiURL)
|
log.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -390,6 +398,7 @@ func TestAlertsGetAsMachine(t *testing.T) {
|
||||||
|
|
||||||
alerts, resp, err := client.Alerts.GetByID(context.Background(), 1)
|
alerts, resp, err := client.Alerts.GetByID(context.Background(), 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
@ -401,7 +410,6 @@ func TestAlertsGetAsMachine(t *testing.T) {
|
||||||
//fail
|
//fail
|
||||||
_, _, err = client.Alerts.GetByID(context.Background(), 2)
|
_, _, err = client.Alerts.GetByID(context.Background(), 2)
|
||||||
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found")
|
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAlertsCreateAsMachine(t *testing.T) {
|
func TestAlertsCreateAsMachine(t *testing.T) {
|
||||||
|
@ -418,10 +426,12 @@ func TestAlertsCreateAsMachine(t *testing.T) {
|
||||||
w.Write([]byte(`["3"]`))
|
w.Write([]byte(`["3"]`))
|
||||||
})
|
})
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("parsing api url: %s", apiURL)
|
log.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -435,13 +445,17 @@ func TestAlertsCreateAsMachine(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
alert := models.AddAlertsRequest{}
|
alert := models.AddAlertsRequest{}
|
||||||
alerts, resp, err := client.Alerts.Add(context.Background(), alert)
|
alerts, resp, err := client.Alerts.Add(context.Background(), alert)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := &models.AddAlertsResponse{"3"}
|
expected := &models.AddAlertsResponse{"3"}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*alerts, *expected) {
|
if !reflect.DeepEqual(*alerts, *expected) {
|
||||||
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
|
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
@ -457,15 +471,17 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
|
||||||
})
|
})
|
||||||
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "DELETE")
|
testMethod(t, r, "DELETE")
|
||||||
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4")
|
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`{"message":"0 deleted alerts"}`))
|
w.Write([]byte(`{"message":"0 deleted alerts"}`))
|
||||||
})
|
})
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("parsing api url: %s", apiURL)
|
log.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -479,15 +495,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
alert := AlertsDeleteOpts{IPEquals: new(string)}
|
alert := AlertsDeleteOpts{IPEquals: new(string)}
|
||||||
*alert.IPEquals = "1.2.3.4"
|
*alert.IPEquals = "1.2.3.4"
|
||||||
alerts, resp, err := client.Alerts.Delete(context.Background(), alert)
|
alerts, resp, err := client.Alerts.Delete(context.Background(), alert)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := &models.DeleteAlertsResponse{NbDeleted: ""}
|
expected := &models.DeleteAlertsResponse{NbDeleted: ""}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*alerts, *expected) {
|
if !reflect.DeepEqual(*alerts, *expected) {
|
||||||
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
|
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,10 +41,13 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// specification of http.RoundTripper.
|
// specification of http.RoundTripper.
|
||||||
req = cloneRequest(req)
|
req = cloneRequest(req)
|
||||||
req.Header.Add("X-Api-Key", t.APIKey)
|
req.Header.Add("X-Api-Key", t.APIKey)
|
||||||
|
|
||||||
if t.UserAgent != "" {
|
if t.UserAgent != "" {
|
||||||
req.Header.Add("User-Agent", t.UserAgent)
|
req.Header.Add("User-Agent", t.UserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("req-api: %s %s", req.Method, req.URL.String())
|
log.Debugf("req-api: %s %s", req.Method, req.URL.String())
|
||||||
|
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
dump, _ := httputil.DumpRequest(req, true)
|
dump, _ := httputil.DumpRequest(req, true)
|
||||||
log.Tracef("auth-api request: %s", string(dump))
|
log.Tracef("auth-api request: %s", string(dump))
|
||||||
|
@ -55,6 +58,7 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
|
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
dump, _ := httputil.DumpResponse(resp, true)
|
dump, _ := httputil.DumpResponse(resp, true)
|
||||||
log.Tracef("auth-api response: %s", string(dump))
|
log.Tracef("auth-api response: %s", string(dump))
|
||||||
|
@ -73,6 +77,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
|
||||||
if t.Transport != nil {
|
if t.Transport != nil {
|
||||||
return t.Transport
|
return t.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
return http.DefaultTransport
|
return http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,15 +95,19 @@ func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
var resp *http.Response
|
var (
|
||||||
var err error
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
backoff := 0
|
backoff := 0
|
||||||
maxAttempts := r.maxAttempts
|
maxAttempts := r.maxAttempts
|
||||||
|
|
||||||
if fflag.DisableHttpRetryBackoff.IsEnabled() {
|
if fflag.DisableHttpRetryBackoff.IsEnabled() {
|
||||||
maxAttempts = 1
|
maxAttempts = 1
|
||||||
}
|
}
|
||||||
|
@ -108,6 +117,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||||
if r.withBackOff {
|
if r.withBackOff {
|
||||||
backoff += 10 + rand.Intn(20)
|
backoff += 10 + rand.Intn(20)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
|
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
|
||||||
select {
|
select {
|
||||||
case <-req.Context().Done():
|
case <-req.Context().Done():
|
||||||
|
@ -115,22 +125,28 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||||
case <-time.After(time.Duration(backoff) * time.Second):
|
case <-time.After(time.Duration(backoff) * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.onBeforeRequest != nil {
|
if r.onBeforeRequest != nil {
|
||||||
r.onBeforeRequest(i)
|
r.onBeforeRequest(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
clonedReq := cloneRequest(req)
|
clonedReq := cloneRequest(req)
|
||||||
resp, err = r.next.RoundTrip(clonedReq)
|
resp, err = r.next.RoundTrip(clonedReq)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
left := maxAttempts - i - 1
|
left := maxAttempts - i - 1
|
||||||
if left > 0 {
|
if left > 0 {
|
||||||
log.Errorf("error while performing request: %s; %d retries left", err, left)
|
log.Errorf("error while performing request: %s; %d retries left", err, left)
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.ShouldRetry(resp.StatusCode) {
|
if !r.ShouldRetry(resp.StatusCode) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,6 +173,7 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("can't update scenario list: %s", err)
|
return fmt.Errorf("can't update scenario list: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("scenarios list updated for '%s'", *t.MachineID)
|
log.Debugf("scenarios list updated for '%s'", *t.MachineID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,14 +192,18 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
enc := json.NewEncoder(buf)
|
enc := json.NewEncoder(buf)
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
err = enc.Encode(auth)
|
err = enc.Encode(auth)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not encode jwt auth body: %w", err)
|
return fmt.Errorf("could not encode jwt auth body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
|
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not create request: %w", err)
|
return fmt.Errorf("could not create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Transport: &retryRoundTripper{
|
Transport: &retryRoundTripper{
|
||||||
next: http.DefaultTransport,
|
next: http.DefaultTransport,
|
||||||
|
@ -191,9 +212,11 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
|
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.UserAgent != "" {
|
if t.UserAgent != "" {
|
||||||
req.Header.Add("User-Agent", t.UserAgent)
|
req.Header.Add("User-Agent", t.UserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
dump, _ := httputil.DumpRequest(req, true)
|
dump, _ := httputil.DumpRequest(req, true)
|
||||||
log.Tracef("auth-jwt request: %s", string(dump))
|
log.Tracef("auth-jwt request: %s", string(dump))
|
||||||
|
@ -205,6 +228,7 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not get jwt token: %w", err)
|
return fmt.Errorf("could not get jwt token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("auth-jwt : http %d", resp.StatusCode)
|
log.Debugf("auth-jwt : http %d", resp.StatusCode)
|
||||||
|
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
|
@ -226,12 +250,15 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||||
return fmt.Errorf("unable to decode response: %w", err)
|
return fmt.Errorf("unable to decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
|
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
|
||||||
return fmt.Errorf("unable to parse jwt expiration: %w", err)
|
return fmt.Errorf("unable to parse jwt expiration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Token = response.Token
|
t.Token = response.Token
|
||||||
|
|
||||||
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
|
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,6 +294,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
dump, _ := httputil.DumpResponse(resp, true)
|
dump, _ := httputil.DumpResponse(resp, true)
|
||||||
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
|
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
|
||||||
t.Token = ""
|
t.Token = ""
|
||||||
|
@ -333,9 +361,12 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||||
|
|
||||||
if r.Body != nil {
|
if r.Body != nil {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
||||||
b.ReadFrom(r.Body)
|
b.ReadFrom(r.Body)
|
||||||
|
|
||||||
r.Body = io.NopCloser(&b)
|
r.Body = io.NopCloser(&b)
|
||||||
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
|
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
return r2
|
return r2
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ type enrollRequest struct {
|
||||||
|
|
||||||
func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) {
|
func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) {
|
||||||
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -31,6 +32,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,6 +48,7 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,6 +56,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
|
||||||
var authResp models.WatcherAuthResponse
|
var authResp models.WatcherAuthResponse
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodPost, u, &auth)
|
req, err := s.client.NewRequest(http.MethodPost, u, &auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return authResp, nil, err
|
return authResp, nil, err
|
||||||
|
@ -62,11 +66,13 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return authResp, resp, err
|
return authResp, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return authResp, resp, nil
|
return authResp, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {
|
func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {
|
||||||
u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
|
req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -76,5 +82,6 @@ func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ func getLoginsForMockErrorCases() map[string]int {
|
||||||
|
|
||||||
func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
|
func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
|
||||||
loginsForMockErrorCases := getLoginsForMockErrorCases()
|
loginsForMockErrorCases := getLoginsForMockErrorCases()
|
||||||
|
|
||||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "POST")
|
testMethod(t, r, "POST")
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
@ -71,7 +72,6 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
|
||||||
* 400, 409, 500 => Error
|
* 400, 409, 500 => Error
|
||||||
*/
|
*/
|
||||||
func TestWatcherRegister(t *testing.T) {
|
func TestWatcherRegister(t *testing.T) {
|
||||||
|
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
mux, urlx, teardown := setup()
|
mux, urlx, teardown := setup()
|
||||||
|
@ -79,6 +79,7 @@ func TestWatcherRegister(t *testing.T) {
|
||||||
//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
|
//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
|
||||||
initBasicMuxMock(t, mux, "/watchers")
|
initBasicMuxMock(t, mux, "/watchers")
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
|
@ -92,16 +93,19 @@ func TestWatcherRegister(t *testing.T) {
|
||||||
URL: apiURL,
|
URL: apiURL,
|
||||||
VersionPrefix: "v1",
|
VersionPrefix: "v1",
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := RegisterClient(&clientconfig, &http.Client{})
|
client, err := RegisterClient(&clientconfig, &http.Client{})
|
||||||
if client == nil || err != nil {
|
if client == nil || err != nil {
|
||||||
t.Fatalf("while registering client : %s", err)
|
t.Fatalf("while registering client : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("->%T", client)
|
log.Printf("->%T", client)
|
||||||
|
|
||||||
// Testing error handling on Registration (400, 409, 500): should retrieve an error
|
// Testing error handling on Registration (400, 409, 500): should retrieve an error
|
||||||
errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError}
|
errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError}
|
||||||
for _, errorCodeToTest := range errorCodesToTest {
|
for _, errorCodeToTest := range errorCodesToTest {
|
||||||
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
|
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
|
||||||
|
|
||||||
client, err = RegisterClient(&clientconfig, &http.Client{})
|
client, err = RegisterClient(&clientconfig, &http.Client{})
|
||||||
if client != nil || err == nil {
|
if client != nil || err == nil {
|
||||||
t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
|
t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
|
||||||
|
@ -112,7 +116,6 @@ func TestWatcherRegister(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWatcherAuth(t *testing.T) {
|
func TestWatcherAuth(t *testing.T) {
|
||||||
|
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
mux, urlx, teardown := setup()
|
mux, urlx, teardown := setup()
|
||||||
|
@ -121,6 +124,7 @@ func TestWatcherAuth(t *testing.T) {
|
||||||
|
|
||||||
initBasicMuxMock(t, mux, "/watchers/login")
|
initBasicMuxMock(t, mux, "/watchers/login")
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
|
@ -169,6 +173,7 @@ func TestWatcherAuth(t *testing.T) {
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
resp.Response.Body.Close()
|
resp.Response.Body.Close()
|
||||||
|
|
||||||
bodyBytes, err := io.ReadAll(resp.Response.Body)
|
bodyBytes, err := io.ReadAll(resp.Response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error while reading body: %s", err.Error())
|
t.Fatalf("error while reading body: %s", err.Error())
|
||||||
|
@ -176,14 +181,13 @@ func TestWatcherAuth(t *testing.T) {
|
||||||
|
|
||||||
log.Printf(string(bodyBytes))
|
log.Printf(string(bodyBytes))
|
||||||
t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
|
t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
|
||||||
} else {
|
|
||||||
log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWatcherUnregister(t *testing.T) {
|
func TestWatcherUnregister(t *testing.T) {
|
||||||
|
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
mux, urlx, teardown := setup()
|
mux, urlx, teardown := setup()
|
||||||
|
@ -192,7 +196,7 @@ func TestWatcherUnregister(t *testing.T) {
|
||||||
|
|
||||||
mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "DELETE")
|
testMethod(t, r, "DELETE")
|
||||||
assert.Equal(t, r.ContentLength, int64(0))
|
assert.Equal(t, int64(0), r.ContentLength)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -211,10 +215,12 @@ func TestWatcherUnregister(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
mycfg := &Config{
|
mycfg := &Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -228,10 +234,12 @@ func TestWatcherUnregister(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.Auth.UnregisterWatcher(context.Background())
|
_, err = client.Auth.UnregisterWatcher(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("while registering client : %s", err)
|
t.Fatalf("while registering client : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("->%T", client)
|
log.Printf("->%T", client)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -264,6 +272,7 @@ func TestWatcherEnroll(t *testing.T) {
|
||||||
fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
|
fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
|
||||||
})
|
})
|
||||||
log.Printf("URL is %s", urlx)
|
log.Printf("URL is %s", urlx)
|
||||||
|
|
||||||
apiURL, err := url.Parse(urlx + "/")
|
apiURL, err := url.Parse(urlx + "/")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
|
|
|
@ -18,7 +18,7 @@ func TestApiAuth(t *testing.T) {
|
||||||
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "GET")
|
testMethod(t, r, "GET")
|
||||||
if r.Header.Get("X-Api-Key") == "ixu" {
|
if r.Header.Get("X-Api-Key") == "ixu" {
|
||||||
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4")
|
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`null`))
|
w.Write([]byte(`null`))
|
||||||
} else {
|
} else {
|
||||||
|
@ -66,9 +66,11 @@ func TestApiAuth(t *testing.T) {
|
||||||
_, resp, err = newcli.Decisions.List(context.Background(), alert)
|
_, resp, err = newcli.Decisions.List(context.Background(), alert)
|
||||||
|
|
||||||
log.Infof("--> %s", err)
|
log.Infof("--> %s", err)
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusForbidden {
|
if resp.Response.StatusCode != http.StatusForbidden {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Contains(t, err.Error(), "API error: access forbidden")
|
assert.Contains(t, err.Error(), "API error: access forbidden")
|
||||||
//ko empty token
|
//ko empty token
|
||||||
auth = &APIKeyTransport{}
|
auth = &APIKeyTransport{}
|
||||||
|
@ -82,5 +84,4 @@ func TestApiAuth(t *testing.T) {
|
||||||
|
|
||||||
log.Infof("--> %s", err)
|
log.Infof("--> %s", err)
|
||||||
assert.Contains(t, err.Error(), "APIKey is empty")
|
assert.Contains(t, err.Error(), "APIKey is empty")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,6 +45,21 @@ func (a *ApiClient) GetClient() *http.Client {
|
||||||
return a.client
|
return a.client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *ApiClient) IsEnrolled() bool {
|
||||||
|
jwtTransport := a.client.Transport.(*JWTTransport)
|
||||||
|
tokenStr := jwtTransport.Token
|
||||||
|
|
||||||
|
token, _ := jwt.Parse(tokenStr, nil)
|
||||||
|
if token == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
|
_, ok := claims["organization_id"]
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
type service struct {
|
type service struct {
|
||||||
client *ApiClient
|
client *ApiClient
|
||||||
}
|
}
|
||||||
|
@ -59,12 +76,15 @@ func NewClient(config *Config) (*ApiClient, error) {
|
||||||
}
|
}
|
||||||
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
||||||
tlsconfig.RootCAs = CaCertPool
|
tlsconfig.RootCAs = CaCertPool
|
||||||
|
|
||||||
if Cert != nil {
|
if Cert != nil {
|
||||||
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
|
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
ht.TLSClientConfig = &tlsconfig
|
ht.TLSClientConfig = &tlsconfig
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
|
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
|
||||||
c.common.client = c
|
c.common.client = c
|
||||||
c.Decisions = (*DecisionsService)(&c.common)
|
c.Decisions = (*DecisionsService)(&c.common)
|
||||||
|
@ -81,16 +101,20 @@ func NewClient(config *Config) (*ApiClient, error) {
|
||||||
func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) {
|
func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) {
|
||||||
if client == nil {
|
if client == nil {
|
||||||
client = &http.Client{}
|
client = &http.Client{}
|
||||||
|
|
||||||
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
|
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||||
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
||||||
tlsconfig.RootCAs = CaCertPool
|
tlsconfig.RootCAs = CaCertPool
|
||||||
|
|
||||||
if Cert != nil {
|
if Cert != nil {
|
||||||
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
||||||
}
|
}
|
||||||
|
|
||||||
ht.TLSClientConfig = &tlsconfig
|
ht.TLSClientConfig = &tlsconfig
|
||||||
client.Transport = ht
|
client.Transport = ht
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix}
|
c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix}
|
||||||
c.common.client = c
|
c.common.client = c
|
||||||
c.Decisions = (*DecisionsService)(&c.common)
|
c.Decisions = (*DecisionsService)(&c.common)
|
||||||
|
@ -108,11 +132,13 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
|
||||||
if client == nil {
|
if client == nil {
|
||||||
client = &http.Client{}
|
client = &http.Client{}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
||||||
if Cert != nil {
|
if Cert != nil {
|
||||||
tlsconfig.RootCAs = CaCertPool
|
tlsconfig.RootCAs = CaCertPool
|
||||||
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
tlsconfig.Certificates = []tls.Certificate{*Cert}
|
||||||
}
|
}
|
||||||
|
|
||||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
|
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
|
||||||
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
|
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
|
||||||
c.common.client = c
|
c.common.client = c
|
||||||
|
@ -126,10 +152,11 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
|
||||||
if resp != nil && resp.Response != nil {
|
if resp != nil && resp.Response != nil {
|
||||||
return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err)
|
return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err)
|
return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err)
|
||||||
}
|
}
|
||||||
return c, nil
|
|
||||||
|
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
|
@ -148,6 +175,7 @@ func (e *ErrorResponse) Error() string {
|
||||||
if len(e.Errors) > 0 {
|
if len(e.Errors) > 0 {
|
||||||
err += fmt.Sprintf(" (%s)", e.Errors)
|
err += fmt.Sprintf(" (%s)", e.Errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +188,9 @@ func CheckResponse(r *http.Response) error {
|
||||||
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
|
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
errorResponse := &ErrorResponse{}
|
errorResponse := &ErrorResponse{}
|
||||||
|
|
||||||
data, err := io.ReadAll(r.Body)
|
data, err := io.ReadAll(r.Body)
|
||||||
if err == nil && data != nil {
|
if err == nil && data != nil {
|
||||||
err := json.Unmarshal(data, errorResponse)
|
err := json.Unmarshal(data, errorResponse)
|
||||||
|
@ -171,6 +201,7 @@ func CheckResponse(r *http.Response) error {
|
||||||
errorResponse.Message = new(string)
|
errorResponse.Message = new(string)
|
||||||
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
|
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
return errorResponse
|
return errorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
|
||||||
if !strings.HasSuffix(c.BaseURL.Path, "/") {
|
if !strings.HasSuffix(c.BaseURL.Path, "/") {
|
||||||
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
|
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := c.BaseURL.Parse(url)
|
u, err := c.BaseURL.Parse(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -29,8 +30,8 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
|
||||||
buf = &bytes.Buffer{}
|
buf = &bytes.Buffer{}
|
||||||
enc := json.NewEncoder(buf)
|
enc := json.NewEncoder(buf)
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
err := enc.Encode(body)
|
|
||||||
if err != nil {
|
if err = enc.Encode(body); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -51,6 +52,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return nil, errors.New("context must be non-nil")
|
return nil, errors.New("context must be non-nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
|
@ -62,6 +64,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
|
||||||
if log.GetLevel() >= log.DebugLevel {
|
if log.GetLevel() >= log.DebugLevel {
|
||||||
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
if resp != nil && resp.Body != nil {
|
if resp != nil && resp.Body != nil {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
@ -82,8 +85,10 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
|
||||||
e.URL = url.String()
|
e.URL = url.String()
|
||||||
return newResponse(resp), e
|
return newResponse(resp), e
|
||||||
}
|
}
|
||||||
|
|
||||||
return newResponse(resp), err
|
return newResponse(resp), err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newResponse(resp), err
|
return newResponse(resp), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,9 +117,12 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
|
||||||
if errors.Is(decErr, io.EOF) {
|
if errors.Is(decErr, io.EOF) {
|
||||||
decErr = nil // ignore EOF errors caused by empty response body
|
decErr = nil // ignore EOF errors caused by empty response body
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, decErr
|
return response, decErr
|
||||||
}
|
}
|
||||||
|
|
||||||
io.Copy(w, resp.Body)
|
io.Copy(w, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ func TestNewRequestInvalid(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
@ -54,6 +55,7 @@ func TestNewRequestTimeout(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing api url: %s", apiURL)
|
t.Fatalf("parsing api url: %s", apiURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := NewClient(&Config{
|
client, err := NewClient(&Config{
|
||||||
MachineID: "test_login",
|
MachineID: "test_login",
|
||||||
Password: "test_password",
|
Password: "test_password",
|
||||||
|
|
|
@ -40,6 +40,7 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te
|
||||||
|
|
||||||
func testMethod(t *testing.T, r *http.Request, want string) {
|
func testMethod(t *testing.T, r *http.Request, want string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if got := r.Method; got != want {
|
if got := r.Method; got != want {
|
||||||
t.Errorf("Request method: %v, want %v", got, want)
|
t.Errorf("Request method: %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
@ -77,6 +78,7 @@ func TestNewClientOk(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("test Unable to list alerts : %+v", err)
|
t.Fatalf("test Unable to list alerts : %+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
|
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
@ -126,6 +128,7 @@ func TestNewDefaultClient(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
|
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
|
||||||
|
@ -157,6 +160,7 @@ func TestNewClientRegisterKO(t *testing.T) {
|
||||||
func TestNewClientRegisterOK(t *testing.T) {
|
func TestNewClientRegisterOK(t *testing.T) {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
mux, urlx, teardown := setup()
|
mux, urlx, teardown := setup()
|
||||||
|
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
/*mock login*/
|
/*mock login*/
|
||||||
|
@ -180,12 +184,14 @@ func TestNewClientRegisterOK(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("while registering client : %s", err)
|
t.Fatalf("while registering client : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("->%T", client)
|
log.Printf("->%T", client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClientBadAnswer(t *testing.T) {
|
func TestNewClientBadAnswer(t *testing.T) {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
mux, urlx, teardown := setup()
|
mux, urlx, teardown := setup()
|
||||||
|
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
/*mock login*/
|
/*mock login*/
|
||||||
|
|
|
@ -42,6 +42,7 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s?%s", url, params.Encode()), nil
|
return fmt.Sprintf("%s?%s", url, params.Encode()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,10 +62,12 @@ type DecisionsDeleteOpts struct {
|
||||||
// to demo query arguments
|
// to demo query arguments
|
||||||
func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) {
|
func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) {
|
||||||
var decisions models.GetDecisionsResponse
|
var decisions models.GetDecisionsResponse
|
||||||
|
|
||||||
params, err := qs.Values(opts)
|
params, err := qs.Values(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||||
|
@ -111,14 +114,18 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.
|
||||||
Origin: ptr.Of(types.CAPIOrigin),
|
Origin: ptr.Of(types.CAPIOrigin),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
decisions = append(decisions, partialDecisions...)
|
decisions = append(decisions, partialDecisions...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return decisions
|
return decisions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
|
func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
|
||||||
var decisions modelscapi.GetDecisionsStreamResponse
|
var (
|
||||||
var v2Decisions models.DecisionsStreamResponse
|
decisions modelscapi.GetDecisionsStreamResponse
|
||||||
|
v2Decisions models.DecisionsStreamResponse
|
||||||
|
)
|
||||||
|
|
||||||
scenarioDeleted := "deleted"
|
scenarioDeleted := "deleted"
|
||||||
durationDeleted := "1h"
|
durationDeleted := "1h"
|
||||||
|
@ -134,8 +141,10 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
|
||||||
}
|
}
|
||||||
|
|
||||||
v2Decisions.New = s.GetDecisionsFromGroups(decisions.New)
|
v2Decisions.New = s.GetDecisionsFromGroups(decisions.New)
|
||||||
|
|
||||||
for _, decisionsGroup := range decisions.Deleted {
|
for _, decisionsGroup := range decisions.Deleted {
|
||||||
partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions))
|
partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions))
|
||||||
|
|
||||||
for idx, decision := range decisionsGroup.Decisions {
|
for idx, decision := range decisionsGroup.Decisions {
|
||||||
decision := decision // fix exportloopref linter message
|
decision := decision // fix exportloopref linter message
|
||||||
partialDecisions[idx] = &models.Decision{
|
partialDecisions[idx] = &models.Decision{
|
||||||
|
@ -147,6 +156,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
|
||||||
Origin: ptr.Of(types.CAPIOrigin),
|
Origin: ptr.Of(types.CAPIOrigin),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...)
|
v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,6 +171,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
log.Debugf("Fetching blocklist %s", *blocklist.URL)
|
log.Debugf("Fetching blocklist %s", *blocklist.URL)
|
||||||
|
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
|
@ -169,6 +180,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
if lastPullTimestamp != nil {
|
if lastPullTimestamp != nil {
|
||||||
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
|
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
||||||
// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
|
// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
|
||||||
|
@ -188,6 +200,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
|
|
||||||
// If the error type is *url.Error, sanitize its URL before returning.
|
// If the error type is *url.Error, sanitize its URL before returning.
|
||||||
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
|
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
|
||||||
|
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -197,13 +210,17 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
|
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
|
log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
|
||||||
return nil, false, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
decisions := make([]*models.Decision, 0)
|
decisions := make([]*models.Decision, 0)
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
decision := scanner.Text()
|
decision := scanner.Text()
|
||||||
|
@ -227,6 +244,7 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.client.URLPrefix == "v3" {
|
if s.client.URLPrefix == "v3" {
|
||||||
return s.FetchV3Decisions(ctx, u)
|
return s.FetchV3Decisions(ctx, u)
|
||||||
} else {
|
} else {
|
||||||
|
@ -239,6 +257,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var decisions modelscapi.GetDecisionsStreamResponse
|
var decisions modelscapi.GetDecisionsStreamResponse
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
req, err := s.client.NewRequest(http.MethodGet, u, nil)
|
||||||
|
@ -255,8 +274,8 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
|
func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/decisions", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/decisions", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -266,15 +285,18 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) {
|
func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) {
|
||||||
var deleteDecisionResponse models.DeleteDecisionResponse
|
var deleteDecisionResponse models.DeleteDecisionResponse
|
||||||
|
|
||||||
params, err := qs.Values(opts)
|
params, err := qs.Values(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
|
@ -286,12 +308,14 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &deleteDecisionResponse, resp, nil
|
return &deleteDecisionResponse, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*models.DeleteDecisionResponse, *Response, error) {
|
func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) {
|
||||||
var deleteDecisionResponse models.DeleteDecisionResponse
|
var deleteDecisionResponse models.DeleteDecisionResponse
|
||||||
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decision_id)
|
|
||||||
|
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -302,5 +326,6 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &deleteDecisionResponse, resp, nil
|
return &deleteDecisionResponse, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,8 +28,8 @@ func TestDecisionsList(t *testing.T) {
|
||||||
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "GET")
|
testMethod(t, r, "GET")
|
||||||
if r.URL.RawQuery == "ip=1.2.3.4" {
|
if r.URL.RawQuery == "ip=1.2.3.4" {
|
||||||
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4")
|
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
|
||||||
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
|
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`))
|
w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`))
|
||||||
} else {
|
} else {
|
||||||
|
@ -83,6 +83,7 @@ func TestDecisionsList(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*decisions, *expected) {
|
if !reflect.DeepEqual(*decisions, *expected) {
|
||||||
t.Fatalf("returned %+v, want %+v", resp, expected)
|
t.Fatalf("returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
@ -96,8 +97,8 @@ func TestDecisionsList(t *testing.T) {
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
assert.Equal(t, len(*decisions), 0)
|
|
||||||
|
|
||||||
|
assert.Empty(t, *decisions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecisionsStream(t *testing.T) {
|
func TestDecisionsStream(t *testing.T) {
|
||||||
|
@ -107,8 +108,7 @@ func TestDecisionsStream(t *testing.T) {
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
|
||||||
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
|
|
||||||
testMethod(t, r, http.MethodGet)
|
testMethod(t, r, http.MethodGet)
|
||||||
if r.Method == http.MethodGet {
|
if r.Method == http.MethodGet {
|
||||||
if r.URL.RawQuery == "startup=true" {
|
if r.URL.RawQuery == "startup=true" {
|
||||||
|
@ -121,7 +121,7 @@ func TestDecisionsStream(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
|
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
|
||||||
testMethod(t, r, http.MethodDelete)
|
testMethod(t, r, http.MethodDelete)
|
||||||
if r.Method == http.MethodDelete {
|
if r.Method == http.MethodDelete {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
@ -173,6 +173,7 @@ func TestDecisionsStream(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*decisions, *expected) {
|
if !reflect.DeepEqual(*decisions, *expected) {
|
||||||
t.Fatalf("returned %+v, want %+v", resp, expected)
|
t.Fatalf("returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
@ -184,8 +185,9 @@ func TestDecisionsStream(t *testing.T) {
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
assert.Equal(t, 0, len(decisions.New))
|
|
||||||
assert.Equal(t, 0, len(decisions.Deleted))
|
assert.Empty(t, decisions.New)
|
||||||
|
assert.Empty(t, decisions.Deleted)
|
||||||
|
|
||||||
//delete stream
|
//delete stream
|
||||||
resp, err = newcli.Decisions.StopStream(context.Background())
|
resp, err = newcli.Decisions.StopStream(context.Background())
|
||||||
|
@ -203,8 +205,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
|
||||||
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
|
|
||||||
testMethod(t, r, http.MethodGet)
|
testMethod(t, r, http.MethodGet)
|
||||||
if r.Method == http.MethodGet {
|
if r.Method == http.MethodGet {
|
||||||
if r.URL.RawQuery == "startup=true" {
|
if r.URL.RawQuery == "startup=true" {
|
||||||
|
@ -275,6 +276,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*decisions, *expected) {
|
if !reflect.DeepEqual(*decisions, *expected) {
|
||||||
t.Fatalf("returned %+v, want %+v", resp, expected)
|
t.Fatalf("returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
@ -287,8 +289,7 @@ func TestDecisionsStreamV3(t *testing.T) {
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
|
||||||
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
|
|
||||||
testMethod(t, r, http.MethodGet)
|
testMethod(t, r, http.MethodGet)
|
||||||
if r.Method == http.MethodGet {
|
if r.Method == http.MethodGet {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
@ -368,6 +369,7 @@ func TestDecisionsStreamV3(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(*decisions, *expected) {
|
if !reflect.DeepEqual(*decisions, *expected) {
|
||||||
t.Fatalf("returned %+v, want %+v", resp, expected)
|
t.Fatalf("returned %+v, want %+v", resp, expected)
|
||||||
}
|
}
|
||||||
|
@ -451,6 +453,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("new api client: %s", err)
|
t.Fatalf("new api client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(decisions, expected) {
|
if !reflect.DeepEqual(decisions, expected) {
|
||||||
t.Fatalf("returned %+v, want %+v", decisions, expected)
|
t.Fatalf("returned %+v, want %+v", decisions, expected)
|
||||||
}
|
}
|
||||||
|
@ -484,7 +487,7 @@ func TestDeleteDecisions(t *testing.T) {
|
||||||
})
|
})
|
||||||
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
testMethod(t, r, "DELETE")
|
testMethod(t, r, "DELETE")
|
||||||
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4")
|
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`{"nbDeleted":"1"}`))
|
w.Write([]byte(`{"nbDeleted":"1"}`))
|
||||||
//w.Write([]byte(`{"message":"0 deleted alerts"}`))
|
//w.Write([]byte(`{"message":"0 deleted alerts"}`))
|
||||||
|
@ -512,6 +515,7 @@ func TestDeleteDecisions(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected err : %s", err)
|
t.Fatalf("unexpected err : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, "1", deleted.NbDeleted)
|
assert.Equal(t, "1", deleted.NbDeleted)
|
||||||
|
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
@ -519,6 +523,7 @@ func TestDeleteDecisions(t *testing.T) {
|
||||||
|
|
||||||
func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
|
func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
|
||||||
baseURLString := "http://localhost:8080/v1/decisions/stream"
|
baseURLString := "http://localhost:8080/v1/decisions/stream"
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Startup bool
|
Startup bool
|
||||||
Scopes string
|
Scopes string
|
||||||
|
@ -553,6 +558,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
|
||||||
want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
|
want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -15,7 +15,9 @@ type DecisionDeleteService service
|
||||||
// DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model
|
// DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model
|
||||||
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
|
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
|
||||||
var response interface{}
|
var response interface{}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
|
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
|
||||||
|
|
||||||
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
|
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("while building request: %w", err)
|
return nil, nil, fmt.Errorf("while building request: %w", err)
|
||||||
|
@ -25,10 +27,12 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, fmt.Errorf("while performing request: %w", err)
|
return nil, resp, fmt.Errorf("while performing request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
log.Warnf("Decisions delete response : http %s", resp.Response.Status)
|
log.Warnf("Decisions delete response : http %s", resp.Response.Status)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Decisions delete response : http %s", resp.Response.Status)
|
log.Debugf("Decisions delete response : http %s", resp.Response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, resp, nil
|
return &response, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
type HeartBeatService service
|
type HeartBeatService service
|
||||||
|
|
||||||
func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
|
func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix)
|
u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix)
|
||||||
|
|
||||||
req, err := h.client.NewRequest(http.MethodGet, u, nil)
|
req, err := h.client.NewRequest(http.MethodGet, u, nil)
|
||||||
|
|
|
@ -14,6 +14,7 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte
|
||||||
var response interface{}
|
var response interface{}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
|
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -23,5 +24,6 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, resp, nil
|
return &response, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque
|
||||||
var response interface{}
|
var response interface{}
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
|
u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodPost, u, &signals)
|
req, err := s.client.NewRequest(http.MethodPost, u, &signals)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("while building request: %w", err)
|
return nil, nil, fmt.Errorf("while building request: %w", err)
|
||||||
|
@ -25,10 +26,12 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, fmt.Errorf("while performing request: %w", err)
|
return nil, resp, fmt.Errorf("while performing request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Response.StatusCode != http.StatusOK {
|
if resp.Response.StatusCode != http.StatusOK {
|
||||||
log.Warnf("Signal push response : http %s", resp.Response.Status)
|
log.Warnf("Signal push response : http %s", resp.Response.Status)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Signal push response : http %s", resp.Response.Status)
|
log.Debugf("Signal push response : http %s", resp.Response.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, resp, nil
|
return &response, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,13 +9,13 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LAPI struct {
|
type LAPI struct {
|
||||||
|
@ -57,6 +57,7 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.t.Fatal(err)
|
l.t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if authType == "apikey" {
|
if authType == "apikey" {
|
||||||
req.Header.Add("X-Api-Key", l.bouncerKey)
|
req.Header.Add("X-Api-Key", l.bouncerKey)
|
||||||
} else if authType == "password" {
|
} else if authType == "password" {
|
||||||
|
@ -64,7 +65,9 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
|
||||||
} else {
|
} else {
|
||||||
l.t.Fatal("auth type not supported")
|
l.t.Fatal("auth type not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
l.router.ServeHTTP(w, req)
|
l.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
return w
|
return w
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +81,7 @@ func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, models.WatcherAuthResponse{}, config, err
|
return nil, models.WatcherAuthResponse{}, config, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return router, loginResp, config, nil
|
return router, loginResp, config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +154,6 @@ func TestCreateAlert(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateAlertChannels(t *testing.T) {
|
func TestCreateAlertChannels(t *testing.T) {
|
||||||
|
|
||||||
apiServer, config, err := NewAPIServer(t)
|
apiServer, config, err := NewAPIServer(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
|
@ -164,18 +167,22 @@ func TestCreateAlertChannels(t *testing.T) {
|
||||||
}
|
}
|
||||||
lapi := LAPI{router: apiServer.router, loginResp: loginResp}
|
lapi := LAPI{router: apiServer.router, loginResp: loginResp}
|
||||||
|
|
||||||
var pd csplugin.ProfileAlert
|
var (
|
||||||
var wg sync.WaitGroup
|
pd csplugin.ProfileAlert
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
pd = <-apiServer.controller.PluginChannel
|
pd = <-apiServer.controller.PluginChannel
|
||||||
|
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
|
go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
assert.Equal(t, len(pd.Alert.Decisions), 1)
|
assert.Len(t, pd.Alert.Decisions, 1)
|
||||||
apiServer.Close()
|
apiServer.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -345,7 +352,6 @@ func TestAlertListFilters(t *testing.T) {
|
||||||
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
|
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
|
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAlertBulkInsert(t *testing.T) {
|
func TestAlertBulkInsert(t *testing.T) {
|
||||||
|
@ -393,7 +399,6 @@ func TestCreateAlertErrors(t *testing.T) {
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s"))
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s"))
|
||||||
lapi.router.ServeHTTP(w, req)
|
lapi.router.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 401, w.Code)
|
assert.Equal(t, 401, w.Code)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteAlert(t *testing.T) {
|
func TestDeleteAlert(t *testing.T) {
|
||||||
|
@ -506,5 +511,4 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
|
||||||
|
|
||||||
lapi.InsertAlertFromFile("./tests/alert_sample.json")
|
lapi.InsertAlertFromFile("./tests/alert_sample.json")
|
||||||
assertAlertDeletedFromIP("127.0.0.1")
|
assertAlertDeletedFromIP("127.0.0.1")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,5 +48,4 @@ func TestAPIKey(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Equal(t, "null", w.Body.String())
|
assert.Equal(t, "null", w.Body.String())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,12 +75,14 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
|
||||||
if ret <= 0 {
|
if ret <= 0 {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) FetchScenariosListFromDB() ([]string, error) {
|
func (a *apic) FetchScenariosListFromDB() ([]string, error) {
|
||||||
scenarios := make([]string, 0)
|
scenarios := make([]string, 0)
|
||||||
machines, err := a.dbClient.ListMachines()
|
machines, err := a.dbClient.ListMachines()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("while listing machines: %w", err)
|
return nil, fmt.Errorf("while listing machines: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -88,18 +90,22 @@ func (a *apic) FetchScenariosListFromDB() ([]string, error) {
|
||||||
for _, v := range machines {
|
for _, v := range machines {
|
||||||
machineScenarios := strings.Split(v.Scenarios, ",")
|
machineScenarios := strings.Split(v.Scenarios, ",")
|
||||||
log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID)
|
log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID)
|
||||||
|
|
||||||
for _, sv := range machineScenarios {
|
for _, sv := range machineScenarios {
|
||||||
if !slices.Contains(scenarios, sv) && sv != "" {
|
if !slices.Contains(scenarios, sv) && sv != "" {
|
||||||
scenarios = append(scenarios, sv)
|
scenarios = append(scenarios, sv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Returning list of scenarios : %+v", scenarios)
|
log.Debugf("Returning list of scenarios : %+v", scenarios)
|
||||||
|
|
||||||
return scenarios, nil
|
return scenarios, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions {
|
func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions {
|
||||||
apiDecisions := models.AddSignalsRequestItemDecisions{}
|
apiDecisions := models.AddSignalsRequestItemDecisions{}
|
||||||
|
|
||||||
for _, decision := range decisions {
|
for _, decision := range decisions {
|
||||||
x := &models.AddSignalsRequestItemDecisionsItem{
|
x := &models.AddSignalsRequestItemDecisionsItem{
|
||||||
Duration: ptr.Of(*decision.Duration),
|
Duration: ptr.Of(*decision.Duration),
|
||||||
|
@ -114,11 +120,14 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ
|
||||||
UUID: decision.UUID,
|
UUID: decision.UUID,
|
||||||
}
|
}
|
||||||
*x.ID = decision.ID
|
*x.ID = decision.ID
|
||||||
|
|
||||||
if decision.Simulated != nil {
|
if decision.Simulated != nil {
|
||||||
x.Simulated = *decision.Simulated
|
x.Simulated = *decision.Simulated
|
||||||
}
|
}
|
||||||
|
|
||||||
apiDecisions = append(apiDecisions, x)
|
apiDecisions = append(apiDecisions, x)
|
||||||
}
|
}
|
||||||
|
|
||||||
return apiDecisions
|
return apiDecisions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +158,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool)
|
||||||
}
|
}
|
||||||
if shareContext {
|
if shareContext {
|
||||||
signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0)
|
signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0)
|
||||||
|
|
||||||
for _, meta := range alert.Meta {
|
for _, meta := range alert.Meta {
|
||||||
contextItem := models.AddSignalsRequestItemContextItems0{
|
contextItem := models.AddSignalsRequestItemContextItems0{
|
||||||
Key: meta.Key,
|
Key: meta.Key,
|
||||||
|
@ -157,13 +167,14 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool)
|
||||||
signal.Context = append(signal.Context, &contextItem)
|
signal.Context = append(signal.Context, &contextItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return signal
|
return signal
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) {
|
func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) {
|
||||||
var err error
|
var err error
|
||||||
ret := &apic{
|
|
||||||
|
|
||||||
|
ret := &apic{
|
||||||
AlertsAddChan: make(chan []*models.Alert),
|
AlertsAddChan: make(chan []*models.Alert),
|
||||||
dbClient: dbClient,
|
dbClient: dbClient,
|
||||||
mu: sync.Mutex{},
|
mu: sync.Mutex{},
|
||||||
|
@ -186,9 +197,11 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
|
||||||
|
|
||||||
password := strfmt.Password(config.Credentials.Password)
|
password := strfmt.Password(config.Credentials.Password)
|
||||||
apiURL, err := url.Parse(config.Credentials.URL)
|
apiURL, err := url.Parse(config.Credentials.URL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err)
|
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
papiURL, err := url.Parse(config.Credentials.PapiURL)
|
papiURL, err := url.Parse(config.Credentials.PapiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
|
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
|
||||||
|
@ -198,6 +211,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
|
return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ret.apiClient, err = apiclient.NewClient(&apiclient.Config{
|
ret.apiClient, err = apiclient.NewClient(&apiclient.Config{
|
||||||
MachineID: config.Credentials.Login,
|
MachineID: config.Credentials.Login,
|
||||||
Password: password,
|
Password: password,
|
||||||
|
@ -228,7 +242,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
|
||||||
return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err)
|
return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
|
if err = ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
|
||||||
return ret, fmt.Errorf("unable to parse jwt expiration: %w", err)
|
return ret, fmt.Errorf("unable to parse jwt expiration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,6 +256,7 @@ func (a *apic) Push() error {
|
||||||
defer trace.CatchPanic("lapi/pushToAPIC")
|
defer trace.CatchPanic("lapi/pushToAPIC")
|
||||||
|
|
||||||
var cache models.AddSignalsRequest
|
var cache models.AddSignalsRequest
|
||||||
|
|
||||||
ticker := time.NewTicker(a.pushIntervalFirst)
|
ticker := time.NewTicker(a.pushIntervalFirst)
|
||||||
|
|
||||||
log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval)
|
log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval)
|
||||||
|
@ -252,28 +267,35 @@ func (a *apic) Push() error {
|
||||||
a.pullTomb.Kill(nil)
|
a.pullTomb.Kill(nil)
|
||||||
a.metricsTomb.Kill(nil)
|
a.metricsTomb.Kill(nil)
|
||||||
log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache))
|
log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache))
|
||||||
|
|
||||||
if len(cache) == 0 {
|
if len(cache) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
go a.Send(&cache)
|
go a.Send(&cache)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
ticker.Reset(a.pushInterval)
|
ticker.Reset(a.pushInterval)
|
||||||
|
|
||||||
if len(cache) > 0 {
|
if len(cache) > 0 {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
cacheCopy := cache
|
cacheCopy := cache
|
||||||
cache = make(models.AddSignalsRequest, 0)
|
cache = make(models.AddSignalsRequest, 0)
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
log.Infof("Signal push: %d signals to push", len(cacheCopy))
|
log.Infof("Signal push: %d signals to push", len(cacheCopy))
|
||||||
|
|
||||||
go a.Send(&cacheCopy)
|
go a.Send(&cacheCopy)
|
||||||
}
|
}
|
||||||
case alerts := <-a.AlertsAddChan:
|
case alerts := <-a.AlertsAddChan:
|
||||||
var signals []*models.AddSignalsRequestItem
|
var signals []*models.AddSignalsRequestItem
|
||||||
|
|
||||||
for _, alert := range alerts {
|
for _, alert := range alerts {
|
||||||
if ok := shouldShareAlert(alert, a.consoleConfig); ok {
|
if ok := shouldShareAlert(alert, a.consoleConfig); ok {
|
||||||
signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext))
|
signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
cache = append(cache, signals...)
|
cache = append(cache, signals...)
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
|
@ -288,11 +310,13 @@ func getScenarioTrustOfAlert(alert *models.Alert) string {
|
||||||
} else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" {
|
} else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" {
|
||||||
scenarioTrust = "tainted"
|
scenarioTrust = "tainted"
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(alert.Decisions) > 0 {
|
if len(alert.Decisions) > 0 {
|
||||||
if *alert.Decisions[0].Origin == types.CscliOrigin {
|
if *alert.Decisions[0].Origin == types.CscliOrigin {
|
||||||
scenarioTrust = "manual"
|
scenarioTrust = "manual"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scenarioTrust
|
return scenarioTrust
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -301,6 +325,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
|
||||||
log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID)
|
log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust {
|
switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust {
|
||||||
case "manual":
|
case "manual":
|
||||||
if !*consoleConfig.ShareManualDecisions {
|
if !*consoleConfig.ShareManualDecisions {
|
||||||
|
@ -318,6 +343,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,34 +359,44 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
|
||||||
|
|
||||||
I don't know enough about gin to tell how much of an issue it can be.
|
I don't know enough about gin to tell how much of an issue it can be.
|
||||||
*/
|
*/
|
||||||
var cache []*models.AddSignalsRequestItem = *cacheOrig
|
var (
|
||||||
var send models.AddSignalsRequest
|
cache []*models.AddSignalsRequestItem = *cacheOrig
|
||||||
|
send models.AddSignalsRequest
|
||||||
|
)
|
||||||
|
|
||||||
bulkSize := 50
|
bulkSize := 50
|
||||||
pageStart := 0
|
pageStart := 0
|
||||||
pageEnd := bulkSize
|
pageEnd := bulkSize
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
if pageEnd >= len(cache) {
|
if pageEnd >= len(cache) {
|
||||||
send = cache[pageStart:]
|
send = cache[pageStart:]
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, _, err := a.apiClient.Signal.Add(ctx, &send)
|
_, _, err := a.apiClient.Signal.Add(ctx, &send)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("sending signal to central API: %s", err)
|
log.Errorf("sending signal to central API: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
send = cache[pageStart:pageEnd]
|
send = cache[pageStart:pageEnd]
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, _, err := a.apiClient.Signal.Add(ctx, &send)
|
_, _, err := a.apiClient.Signal.Add(ctx, &send)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//we log it here as well, because the return value of func might be discarded
|
//we log it here as well, because the return value of func might be discarded
|
||||||
log.Errorf("sending signal to central API: %s", err)
|
log.Errorf("sending signal to central API: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pageStart += bulkSize
|
pageStart += bulkSize
|
||||||
pageEnd += bulkSize
|
pageEnd += bulkSize
|
||||||
}
|
}
|
||||||
|
@ -372,18 +408,22 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
|
||||||
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
|
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
|
||||||
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
|
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
|
||||||
count, err := alerts.Count(a.dbClient.CTX)
|
count, err := alerts.Count(a.dbClient.CTX)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("while looking for CAPI alert: %w", err)
|
return false, fmt.Errorf("while looking for CAPI alert: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
log.Printf("last CAPI pull is newer than 1h30, skip.")
|
log.Printf("last CAPI pull is newer than 1h30, skip.")
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) {
|
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
|
||||||
nbDeleted := 0
|
nbDeleted := 0
|
||||||
|
|
||||||
for _, decision := range deletedDecisions {
|
for _, decision := range deletedDecisions {
|
||||||
filter := map[string][]string{
|
filter := map[string][]string{
|
||||||
"value": {*decision.Value},
|
"value": {*decision.Value},
|
||||||
|
@ -398,20 +438,25 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("deleting decisions error: %w", err)
|
return 0, fmt.Errorf("deleting decisions error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbCliDel, err := strconv.Atoi(dbCliRet)
|
dbCliDel, err := strconv.Atoi(dbCliRet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
|
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
|
||||||
}
|
}
|
||||||
updateCounterForDecision(delete_counters, decision.Origin, decision.Scenario, dbCliDel)
|
|
||||||
|
updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel)
|
||||||
nbDeleted += dbCliDel
|
nbDeleted += dbCliDel
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbDeleted, nil
|
return nbDeleted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) {
|
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
|
||||||
var nbDeleted int
|
var nbDeleted int
|
||||||
|
|
||||||
for _, decisions := range deletedDecisions {
|
for _, decisions := range deletedDecisions {
|
||||||
scope := decisions.Scope
|
scope := decisions.Scope
|
||||||
|
|
||||||
for _, decision := range decisions.Decisions {
|
for _, decision := range decisions.Decisions {
|
||||||
filter := map[string][]string{
|
filter := map[string][]string{
|
||||||
"value": {decision},
|
"value": {decision},
|
||||||
|
@ -425,26 +470,32 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("deleting decisions error: %w", err)
|
return 0, fmt.Errorf("deleting decisions error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbCliDel, err := strconv.Atoi(dbCliRet)
|
dbCliDel, err := strconv.Atoi(dbCliRet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
|
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
|
||||||
}
|
}
|
||||||
updateCounterForDecision(delete_counters, ptr.Of(types.CAPIOrigin), nil, dbCliDel)
|
|
||||||
|
updateCounterForDecision(deleteCounters, ptr.Of(types.CAPIOrigin), nil, dbCliDel)
|
||||||
nbDeleted += dbCliDel
|
nbDeleted += dbCliDel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbDeleted, nil
|
return nbDeleted, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
|
func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
|
||||||
newAlerts := make([]*models.Alert, 0)
|
newAlerts := make([]*models.Alert, 0)
|
||||||
|
|
||||||
for _, decision := range decisions {
|
for _, decision := range decisions {
|
||||||
found := false
|
found := false
|
||||||
|
|
||||||
for _, sub := range newAlerts {
|
for _, sub := range newAlerts {
|
||||||
if sub.Source.Scope == nil {
|
if sub.Source.Scope == nil {
|
||||||
log.Warningf("nil scope in %+v", sub)
|
log.Warningf("nil scope in %+v", sub)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if *decision.Origin == types.CAPIOrigin {
|
if *decision.Origin == types.CAPIOrigin {
|
||||||
if *sub.Source.Scope == types.CAPIOrigin {
|
if *sub.Source.Scope == types.CAPIOrigin {
|
||||||
found = true
|
found = true
|
||||||
|
@ -464,11 +515,13 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
|
||||||
log.Warningf("unknown origin %s : %+v", *decision.Origin, decision)
|
log.Warningf("unknown origin %s : %+v", *decision.Origin, decision)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario)
|
log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario)
|
||||||
newAlerts = append(newAlerts, createAlertForDecision(decision))
|
newAlerts = append(newAlerts, createAlertForDecision(decision))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newAlerts
|
return newAlerts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -489,6 +542,7 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
|
||||||
// XXX: this or nil?
|
// XXX: this or nil?
|
||||||
scenario = ""
|
scenario = ""
|
||||||
scope = ""
|
scope = ""
|
||||||
|
|
||||||
log.Warningf("unknown origin %s", *decision.Origin)
|
log.Warningf("unknown origin %s", *decision.Origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -512,10 +566,10 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function takes in list of parent alerts and decisions and then pairs them up.
|
// This function takes in list of parent alerts and decisions and then pairs them up.
|
||||||
func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, add_counters map[string]map[string]int) []*models.Alert {
|
func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert {
|
||||||
for _, decision := range decisions {
|
for _, decision := range decisions {
|
||||||
//count and create separate alerts for each list
|
//count and create separate alerts for each list
|
||||||
updateCounterForDecision(add_counters, decision.Origin, decision.Scenario, 1)
|
updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1)
|
||||||
|
|
||||||
/*CAPI might send lower case scopes, unify it.*/
|
/*CAPI might send lower case scopes, unify it.*/
|
||||||
switch strings.ToLower(*decision.Scope) {
|
switch strings.ToLower(*decision.Scope) {
|
||||||
|
@ -524,6 +578,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
|
||||||
case "range":
|
case "range":
|
||||||
*decision.Scope = types.Range
|
*decision.Scope = types.Range
|
||||||
}
|
}
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
//add the individual decisions to the right list
|
//add the individual decisions to the right list
|
||||||
for idx, alert := range alerts {
|
for idx, alert := range alerts {
|
||||||
|
@ -531,6 +586,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
|
||||||
if *alert.Source.Scope == types.CAPIOrigin {
|
if *alert.Source.Scope == types.CAPIOrigin {
|
||||||
alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
|
alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
|
||||||
found = true
|
found = true
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
} else if *decision.Origin == types.ListOrigin {
|
} else if *decision.Origin == types.ListOrigin {
|
||||||
|
@ -543,10 +599,12 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
|
||||||
log.Warningf("unknown origin %s", *decision.Origin)
|
log.Warningf("unknown origin %s", *decision.Origin)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario)
|
log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return alerts
|
return alerts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,18 +639,20 @@ func (a *apic) PullTop(forcePull bool) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get stream: %w", err)
|
return fmt.Errorf("get stream: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
a.startup = false
|
a.startup = false
|
||||||
/*to count additions/deletions across lists*/
|
/*to count additions/deletions across lists*/
|
||||||
|
|
||||||
log.Debugf("Received %d new decisions", len(data.New))
|
log.Debugf("Received %d new decisions", len(data.New))
|
||||||
log.Debugf("Received %d deleted decisions", len(data.Deleted))
|
log.Debugf("Received %d deleted decisions", len(data.Deleted))
|
||||||
|
|
||||||
if data.Links != nil {
|
if data.Links != nil {
|
||||||
log.Debugf("Received %d blocklists links", len(data.Links.Blocklists))
|
log.Debugf("Received %d blocklists links", len(data.Links.Blocklists))
|
||||||
}
|
}
|
||||||
|
|
||||||
add_counters, delete_counters := makeAddAndDeleteCounters()
|
addCounters, deleteCounters := makeAddAndDeleteCounters()
|
||||||
// process deleted decisions
|
// process deleted decisions
|
||||||
if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, delete_counters); err != nil {
|
if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil {
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
|
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
|
||||||
|
@ -610,28 +670,30 @@ func (a *apic) PullTop(forcePull bool) error {
|
||||||
|
|
||||||
alert := createAlertForDecision(decisions[0])
|
alert := createAlertForDecision(decisions[0])
|
||||||
alertsFromCapi := []*models.Alert{alert}
|
alertsFromCapi := []*models.Alert{alert}
|
||||||
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters)
|
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
|
||||||
|
|
||||||
err = a.SaveAlerts(alertsFromCapi, add_counters, delete_counters)
|
err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while saving alerts: %w", err)
|
return fmt.Errorf("while saving alerts: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// update blocklists
|
// update blocklists
|
||||||
if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil {
|
if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
|
||||||
return fmt.Errorf("while updating blocklists: %w", err)
|
return fmt.Errorf("while updating blocklists: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
|
// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
|
||||||
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
|
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
|
||||||
add_counters, _ := makeAddAndDeleteCounters()
|
addCounters, _ := makeAddAndDeleteCounters()
|
||||||
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
|
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
|
||||||
Blocklists: []*modelscapi.BlocklistLink{blocklist},
|
Blocklists: []*modelscapi.BlocklistLink{blocklist},
|
||||||
}, add_counters, forcePull); err != nil {
|
}, addCounters, forcePull); err != nil {
|
||||||
return fmt.Errorf("while pulling blocklist: %w", err)
|
return fmt.Errorf("while pulling blocklist: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -641,17 +703,20 @@ func (a *apic) whitelistedBy(decision *models.Decision) string {
|
||||||
if decision.Value == nil {
|
if decision.Value == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
ipval := net.ParseIP(*decision.Value)
|
ipval := net.ParseIP(*decision.Value)
|
||||||
for _, cidr := range a.whitelists.Cidrs {
|
for _, cidr := range a.whitelists.Cidrs {
|
||||||
if cidr.Contains(ipval) {
|
if cidr.Contains(ipval) {
|
||||||
return cidr.String()
|
return cidr.String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range a.whitelists.Ips {
|
for _, ip := range a.whitelists.Ips {
|
||||||
if ip != nil && ip.Equal(ipval) {
|
if ip != nil && ip.Equal(ipval) {
|
||||||
return ip.String()
|
return ip.String()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -661,12 +726,14 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
|
||||||
}
|
}
|
||||||
//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
|
//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
|
||||||
outIdx := 0
|
outIdx := 0
|
||||||
|
|
||||||
for _, decision := range decisions {
|
for _, decision := range decisions {
|
||||||
whitelister := a.whitelistedBy(decision)
|
whitelister := a.whitelistedBy(decision)
|
||||||
if whitelister != "" {
|
if whitelister != "" {
|
||||||
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister)
|
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
decisions[outIdx] = decision
|
decisions[outIdx] = decision
|
||||||
outIdx++
|
outIdx++
|
||||||
}
|
}
|
||||||
|
@ -674,17 +741,20 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
|
||||||
return decisions[:outIdx]
|
return decisions[:outIdx]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) error {
|
func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
|
||||||
for _, alert := range alertsFromCapi {
|
for _, alert := range alertsFromCapi {
|
||||||
setAlertScenario(alert, add_counters, delete_counters)
|
setAlertScenario(alert, addCounters, deleteCounters)
|
||||||
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
|
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
|
||||||
|
|
||||||
if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
|
if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
|
||||||
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
|
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
|
||||||
}
|
}
|
||||||
|
|
||||||
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
|
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
|
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID)
|
log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -697,71 +767,91 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
|
||||||
alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
|
alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
|
||||||
alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
|
alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
|
||||||
alertInstance, err := alertQuery.First(context.Background())
|
alertInstance, err := alertQuery.First(context.Background())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ent.IsNotFound(err) {
|
if ent.IsNotFound(err) {
|
||||||
log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
|
log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, fmt.Errorf("while getting alert: %w", err)
|
return false, fmt.Errorf("while getting alert: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
decisionQuery := a.dbClient.Ent.Decision.Query()
|
decisionQuery := a.dbClient.Ent.Decision.Query()
|
||||||
decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
|
decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
|
||||||
firstDecision, err := decisionQuery.First(context.Background())
|
firstDecision, err := decisionQuery.First(context.Background())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ent.IsNotFound(err) {
|
if ent.IsNotFound(err) {
|
||||||
log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
|
log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, fmt.Errorf("while getting decision: %w", err)
|
return false, fmt.Errorf("while getting decision: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
|
if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
|
||||||
log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
|
log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error {
|
func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
|
||||||
if blocklist.Scope == nil {
|
if blocklist.Scope == nil {
|
||||||
log.Warningf("blocklist has no scope")
|
log.Warningf("blocklist has no scope")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if blocklist.Duration == nil {
|
if blocklist.Duration == nil {
|
||||||
log.Warningf("blocklist has no duration")
|
log.Warningf("blocklist has no duration")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !forcePull {
|
if !forcePull {
|
||||||
_forcePull, err := a.ShouldForcePullBlocklist(blocklist)
|
_forcePull, err := a.ShouldForcePullBlocklist(blocklist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
|
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
forcePull = _forcePull
|
forcePull = _forcePull
|
||||||
}
|
}
|
||||||
|
|
||||||
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
|
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
|
||||||
var lastPullTimestamp *string
|
|
||||||
var err error
|
var (
|
||||||
|
lastPullTimestamp *string
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
if !forcePull {
|
if !forcePull {
|
||||||
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
|
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
|
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
|
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
|
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasChanged {
|
if !hasChanged {
|
||||||
if lastPullTimestamp == nil {
|
if lastPullTimestamp == nil {
|
||||||
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
|
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
|
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
|
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
|
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(decisions) == 0 {
|
if len(decisions) == 0 {
|
||||||
log.Infof("blocklist %s has no decisions", *blocklist.Name)
|
log.Infof("blocklist %s has no decisions", *blocklist.Name)
|
||||||
return nil
|
return nil
|
||||||
|
@ -770,19 +860,21 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
|
||||||
decisions = a.ApplyApicWhitelists(decisions)
|
decisions = a.ApplyApicWhitelists(decisions)
|
||||||
alert := createAlertForDecision(decisions[0])
|
alert := createAlertForDecision(decisions[0])
|
||||||
alertsFromCapi := []*models.Alert{alert}
|
alertsFromCapi := []*models.Alert{alert}
|
||||||
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters)
|
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
|
||||||
|
|
||||||
err = a.SaveAlerts(alertsFromCapi, add_counters, nil)
|
err = a.SaveAlerts(alertsFromCapi, addCounters, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
|
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error {
|
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
|
||||||
if links == nil {
|
if links == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if links.Blocklists == nil {
|
if links.Blocklists == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -792,21 +884,23 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while creating default client: %w", err)
|
return fmt.Errorf("while creating default client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, blocklist := range links.Blocklists {
|
for _, blocklist := range links.Blocklists {
|
||||||
if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil {
|
if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setAlertScenario(alert *models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) {
|
func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) {
|
||||||
if *alert.Source.Scope == types.CAPIOrigin {
|
if *alert.Source.Scope == types.CAPIOrigin {
|
||||||
*alert.Source.Scope = types.CommunityBlocklistPullSourceScope
|
*alert.Source.Scope = types.CommunityBlocklistPullSourceScope
|
||||||
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.CAPIOrigin]["all"], delete_counters[types.CAPIOrigin]["all"]))
|
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"]))
|
||||||
} else if *alert.Source.Scope == types.ListOrigin {
|
} else if *alert.Source.Scope == types.ListOrigin {
|
||||||
*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
|
*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
|
||||||
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.ListOrigin][*alert.Scenario], delete_counters[types.ListOrigin][*alert.Scenario]))
|
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -814,20 +908,26 @@ func (a *apic) Pull() error {
|
||||||
defer trace.CatchPanic("lapi/pullFromAPIC")
|
defer trace.CatchPanic("lapi/pullFromAPIC")
|
||||||
|
|
||||||
toldOnce := false
|
toldOnce := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
scenario, err := a.FetchScenariosListFromDB()
|
scenario, err := a.FetchScenariosListFromDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to fetch scenarios from db: %s", err)
|
log.Errorf("unable to fetch scenarios from db: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(scenario) > 0 {
|
if len(scenario) > 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if !toldOnce {
|
if !toldOnce {
|
||||||
log.Warning("scenario list is empty, will not pull yet")
|
log.Warning("scenario list is empty, will not pull yet")
|
||||||
|
|
||||||
toldOnce = true
|
toldOnce = true
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.PullTop(false); err != nil {
|
if err := a.PullTop(false); err != nil {
|
||||||
log.Errorf("capi pull top: %s", err)
|
log.Errorf("capi pull top: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -839,6 +939,7 @@ func (a *apic) Pull() error {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
ticker.Reset(a.pullInterval)
|
ticker.Reset(a.pullInterval)
|
||||||
|
|
||||||
if err := a.PullTop(false); err != nil {
|
if err := a.PullTop(false); err != nil {
|
||||||
log.Errorf("capi pull top: %s", err)
|
log.Errorf("capi pull top: %s", err)
|
||||||
continue
|
continue
|
||||||
|
@ -846,6 +947,7 @@ func (a *apic) Pull() error {
|
||||||
case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others?
|
case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others?
|
||||||
a.metricsTomb.Kill(nil)
|
a.metricsTomb.Kill(nil)
|
||||||
a.pushTomb.Kill(nil)
|
a.pushTomb.Kill(nil)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -858,15 +960,15 @@ func (a *apic) Shutdown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) {
|
func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) {
|
||||||
add_counters := make(map[string]map[string]int)
|
addCounters := make(map[string]map[string]int)
|
||||||
add_counters[types.CAPIOrigin] = make(map[string]int)
|
addCounters[types.CAPIOrigin] = make(map[string]int)
|
||||||
add_counters[types.ListOrigin] = make(map[string]int)
|
addCounters[types.ListOrigin] = make(map[string]int)
|
||||||
|
|
||||||
delete_counters := make(map[string]map[string]int)
|
deleteCounters := make(map[string]map[string]int)
|
||||||
delete_counters[types.CAPIOrigin] = make(map[string]int)
|
deleteCounters[types.CAPIOrigin] = make(map[string]int)
|
||||||
delete_counters[types.ListOrigin] = make(map[string]int)
|
deleteCounters[types.ListOrigin] = make(map[string]int)
|
||||||
|
|
||||||
return add_counters, delete_counters
|
return addCounters, deleteCounters
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) {
|
func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) {
|
||||||
|
|
|
@ -2,10 +2,10 @@ package apiserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
"github.com/crowdsecurity/go-cs-lib/ptr"
|
||||||
"github.com/crowdsecurity/go-cs-lib/trace"
|
"github.com/crowdsecurity/go-cs-lib/trace"
|
||||||
|
@ -66,6 +66,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) {
|
||||||
}
|
}
|
||||||
// sorted slices are required for the slices.Equal comparison
|
// sorted slices are required for the slices.Equal comparison
|
||||||
slices.Sort(ret)
|
slices.Sort(ret)
|
||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,6 +92,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
|
||||||
if count < len(metInts)-1 {
|
if count < len(metInts)-1 {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
|
|
||||||
return metInts[count]
|
return metInts[count]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,8 +102,10 @@ func (a *apic) SendMetrics(stop chan (bool)) {
|
||||||
ids, err := a.fetchMachineIDs()
|
ids, err := a.fetchMachineIDs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("unable to get machines (%s), will retry", err)
|
log.Debugf("unable to get machines (%s), will retry", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
machineIDs = ids
|
machineIDs = ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,16 +121,20 @@ func (a *apic) SendMetrics(stop chan (bool)) {
|
||||||
case <-stop:
|
case <-stop:
|
||||||
checkTicker.Stop()
|
checkTicker.Stop()
|
||||||
metTicker.Stop()
|
metTicker.Stop()
|
||||||
|
|
||||||
return
|
return
|
||||||
case <-checkTicker.C:
|
case <-checkTicker.C:
|
||||||
oldIDs := machineIDs
|
oldIDs := machineIDs
|
||||||
|
|
||||||
reloadMachineIDs()
|
reloadMachineIDs()
|
||||||
|
|
||||||
if !slices.Equal(oldIDs, machineIDs) {
|
if !slices.Equal(oldIDs, machineIDs) {
|
||||||
log.Infof("capi metrics: machines changed, immediate send")
|
log.Infof("capi metrics: machines changed, immediate send")
|
||||||
metTicker.Reset(1 * time.Millisecond)
|
metTicker.Reset(1 * time.Millisecond)
|
||||||
}
|
}
|
||||||
case <-metTicker.C:
|
case <-metTicker.C:
|
||||||
metTicker.Stop()
|
metTicker.Stop()
|
||||||
|
|
||||||
metrics, err := a.GetMetrics()
|
metrics, err := a.GetMetrics()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to get metrics (%s)", err)
|
log.Errorf("unable to get metrics (%s)", err)
|
||||||
|
@ -134,17 +142,20 @@ func (a *apic) SendMetrics(stop chan (bool)) {
|
||||||
// metrics are nil if they could not be retrieved
|
// metrics are nil if they could not be retrieved
|
||||||
if metrics != nil {
|
if metrics != nil {
|
||||||
log.Info("capi metrics: sending")
|
log.Info("capi metrics: sending")
|
||||||
|
|
||||||
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
|
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("capi metrics: failed: %s", err)
|
log.Errorf("capi metrics: failed: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metTicker.Reset(nextMetInt())
|
metTicker.Reset(nextMetInt())
|
||||||
case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
|
case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
|
||||||
checkTicker.Stop()
|
checkTicker.Stop()
|
||||||
metTicker.Stop()
|
metTicker.Stop()
|
||||||
a.pullTomb.Kill(nil)
|
a.pullTomb.Kill(nil)
|
||||||
a.pushTomb.Kill(nil)
|
a.pushTomb.Kill(nil)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,6 +61,7 @@ func TestAPICSendMetrics(t *testing.T) {
|
||||||
|
|
||||||
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
|
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
|
|
||||||
defer httpmock.Deactivate()
|
defer httpmock.Deactivate()
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
|
|
@ -44,12 +44,14 @@ func getDBClient(t *testing.T) *database.Client {
|
||||||
DbPath: dbPath.Name(),
|
DbPath: dbPath.Name(),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return dbClient
|
return dbClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAPIC(t *testing.T) *apic {
|
func getAPIC(t *testing.T) *apic {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dbClient := getDBClient(t)
|
dbClient := getDBClient(t)
|
||||||
|
|
||||||
return &apic{
|
return &apic{
|
||||||
AlertsAddChan: make(chan []*models.Alert),
|
AlertsAddChan: make(chan []*models.Alert),
|
||||||
//DecisionDeleteChan: make(chan []*models.Decision),
|
//DecisionDeleteChan: make(chan []*models.Decision),
|
||||||
|
@ -74,6 +76,7 @@ func absDiff(a int, b int) (c int) {
|
||||||
if c = a - b; c < 0 {
|
if c = a - b; c < 0 {
|
||||||
return -1 * c
|
return -1 * c
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +97,7 @@ func jsonMarshalX(v interface{}) []byte {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +180,6 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
||||||
|
|
||||||
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
|
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,6 +223,7 @@ func TestNewAPIC(t *testing.T) {
|
||||||
expectedErr: "first path segment in URL cannot contain colon",
|
expectedErr: "first path segment in URL cannot contain colon",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -274,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
|
||||||
Scope: ptr.Of("IP"),
|
Scope: ptr.Of("IP"),
|
||||||
}}, deleteCounters)
|
}}, deleteCounters)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 2, nbDeleted)
|
assert.Equal(t, 2, nbDeleted)
|
||||||
assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"])
|
assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"])
|
||||||
}
|
}
|
||||||
|
@ -338,6 +342,7 @@ func TestAPICGetMetrics(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -394,6 +399,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
|
||||||
Origin: ptr.Of(types.CAPIOrigin),
|
Origin: ptr.Of(types.CAPIOrigin),
|
||||||
Scenario: ptr.Of("crowdsecurity/ssh-bf"),
|
Scenario: ptr.Of("crowdsecurity/ssh-bf"),
|
||||||
}
|
}
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
decisions []*models.Decision
|
decisions []*models.Decision
|
||||||
}
|
}
|
||||||
|
@ -443,6 +449,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -477,6 +484,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
|
||||||
Scenario: ptr.Of("crowdsecurity/ssh-bf"),
|
Scenario: ptr.Of("crowdsecurity/ssh-bf"),
|
||||||
Scope: ptr.Of("ip"),
|
Scope: ptr.Of("ip"),
|
||||||
}
|
}
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
alerts []*models.Alert
|
alerts []*models.Alert
|
||||||
decisions []*models.Decision
|
decisions []*models.Decision
|
||||||
|
@ -520,6 +528,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
@ -546,12 +555,14 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to parse cidr : %s", err)
|
t.Fatalf("unable to parse cidr : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
|
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
|
||||||
cidrwl1 = "11.2.3.0/24"
|
cidrwl1 = "11.2.3.0/24"
|
||||||
_, tnet, err = net.ParseCIDR(cidrwl1)
|
_, tnet, err = net.ParseCIDR(cidrwl1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to parse cidr : %s", err)
|
t.Fatalf("unable to parse cidr : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
|
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
|
||||||
api.dbClient.Ent.Decision.Create().
|
api.dbClient.Ent.Decision.Create().
|
||||||
SetOrigin(types.CAPIOrigin).
|
SetOrigin(types.CAPIOrigin).
|
||||||
|
@ -564,6 +575,7 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
assertTotalDecisionCount(t, api.dbClient, 1)
|
assertTotalDecisionCount(t, api.dbClient, 1)
|
||||||
assertTotalValidDecisionCount(t, api.dbClient, 1)
|
assertTotalValidDecisionCount(t, api.dbClient, 1)
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
|
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
||||||
200, jsonMarshalX(
|
200, jsonMarshalX(
|
||||||
|
@ -681,33 +693,39 @@ func TestAPICWhitelists(t *testing.T) {
|
||||||
AllX(context.Background())
|
AllX(context.Background())
|
||||||
|
|
||||||
decisionScenarioFreq := make(map[string]int)
|
decisionScenarioFreq := make(map[string]int)
|
||||||
decisionIp := make(map[string]int)
|
decisionIP := make(map[string]int)
|
||||||
|
|
||||||
alertScenario := make(map[string]int)
|
alertScenario := make(map[string]int)
|
||||||
|
|
||||||
for _, alert := range alerts {
|
for _, alert := range alerts {
|
||||||
alertScenario[alert.SourceScope]++
|
alertScenario[alert.SourceScope]++
|
||||||
}
|
}
|
||||||
assert.Equal(t, 3, len(alertScenario))
|
|
||||||
|
assert.Len(t, alertScenario, 3)
|
||||||
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
|
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
|
||||||
assert.Equal(t, 1, alertScenario["lists:blocklist1"])
|
assert.Equal(t, 1, alertScenario["lists:blocklist1"])
|
||||||
assert.Equal(t, 1, alertScenario["lists:blocklist2"])
|
assert.Equal(t, 1, alertScenario["lists:blocklist2"])
|
||||||
|
|
||||||
for _, decisions := range validDecisions {
|
for _, decisions := range validDecisions {
|
||||||
decisionScenarioFreq[decisions.Scenario]++
|
decisionScenarioFreq[decisions.Scenario]++
|
||||||
decisionIp[decisions.Value]++
|
decisionIP[decisions.Value]++
|
||||||
}
|
}
|
||||||
assert.Equal(t, 1, decisionIp["2.2.3.4"], 1)
|
|
||||||
assert.Equal(t, 1, decisionIp["6.2.3.4"], 1)
|
assert.Equal(t, 1, decisionIP["2.2.3.4"], 1)
|
||||||
if _, ok := decisionIp["13.2.3.4"]; ok {
|
assert.Equal(t, 1, decisionIP["6.2.3.4"], 1)
|
||||||
|
|
||||||
|
if _, ok := decisionIP["13.2.3.4"]; ok {
|
||||||
t.Errorf("13.2.3.4 is whitelisted")
|
t.Errorf("13.2.3.4 is whitelisted")
|
||||||
}
|
}
|
||||||
if _, ok := decisionIp["13.2.3.5"]; ok {
|
|
||||||
|
if _, ok := decisionIP["13.2.3.5"]; ok {
|
||||||
t.Errorf("13.2.3.5 is whitelisted")
|
t.Errorf("13.2.3.5 is whitelisted")
|
||||||
}
|
}
|
||||||
if _, ok := decisionIp["9.2.3.4"]; ok {
|
|
||||||
|
if _, ok := decisionIP["9.2.3.4"]; ok {
|
||||||
t.Errorf("9.2.3.4 is whitelisted")
|
t.Errorf("9.2.3.4 is whitelisted")
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1)
|
assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1)
|
||||||
assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1)
|
assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1)
|
||||||
assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2)
|
assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2)
|
||||||
|
@ -726,6 +744,7 @@ func TestAPICPullTop(t *testing.T) {
|
||||||
assertTotalDecisionCount(t, api.dbClient, 1)
|
assertTotalDecisionCount(t, api.dbClient, 1)
|
||||||
assertTotalValidDecisionCount(t, api.dbClient, 1)
|
assertTotalValidDecisionCount(t, api.dbClient, 1)
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
|
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
||||||
200, jsonMarshalX(
|
200, jsonMarshalX(
|
||||||
|
@ -817,7 +836,8 @@ func TestAPICPullTop(t *testing.T) {
|
||||||
for _, alert := range alerts {
|
for _, alert := range alerts {
|
||||||
alertScenario[alert.SourceScope]++
|
alertScenario[alert.SourceScope]++
|
||||||
}
|
}
|
||||||
assert.Equal(t, 3, len(alertScenario))
|
|
||||||
|
assert.Len(t, alertScenario, 3)
|
||||||
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
|
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
|
||||||
assert.Equal(t, 1, alertScenario["lists:blocklist1"])
|
assert.Equal(t, 1, alertScenario["lists:blocklist1"])
|
||||||
assert.Equal(t, 1, alertScenario["lists:blocklist2"])
|
assert.Equal(t, 1, alertScenario["lists:blocklist2"])
|
||||||
|
@ -835,6 +855,7 @@ func TestAPICPullTop(t *testing.T) {
|
||||||
func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
||||||
// no decision in db, no last modified parameter.
|
// no decision in db, no last modified parameter.
|
||||||
api := getAPIC(t)
|
api := getAPIC(t)
|
||||||
|
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
|
||||||
|
@ -904,6 +925,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
||||||
|
|
||||||
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
||||||
api := getAPIC(t)
|
api := getAPIC(t)
|
||||||
|
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
// create a decision about to expire. It should force fetch
|
// create a decision about to expire. It should force fetch
|
||||||
|
@ -975,6 +997,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
||||||
|
|
||||||
func TestAPICPullBlocklistCall(t *testing.T) {
|
func TestAPICPullBlocklistCall(t *testing.T) {
|
||||||
api := getAPIC(t)
|
api := getAPIC(t)
|
||||||
|
|
||||||
httpmock.Activate()
|
httpmock.Activate()
|
||||||
defer httpmock.DeactivateAndReset()
|
defer httpmock.DeactivateAndReset()
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,6 @@ package apiserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -15,7 +13,6 @@ import (
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-co-op/gocron"
|
"github.com/go-co-op/gocron"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
|
@ -23,7 +20,6 @@ import (
|
||||||
|
|
||||||
"github.com/crowdsecurity/go-cs-lib/trace"
|
"github.com/crowdsecurity/go-cs-lib/trace"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers"
|
"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers"
|
||||||
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
|
@ -32,9 +28,7 @@ import (
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const keyLength = 32
|
||||||
keyLength = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
type APIServer struct {
|
type APIServer struct {
|
||||||
URL string
|
URL string
|
||||||
|
@ -52,57 +46,117 @@ type APIServer struct {
|
||||||
isEnrolled bool
|
isEnrolled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
|
func recoverFromPanic(c *gin.Context) {
|
||||||
|
err := recover()
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for a broken connection, as it is not really a
|
||||||
|
// condition that warrants a panic stack trace.
|
||||||
|
brokenPipe := false
|
||||||
|
|
||||||
|
if ne, ok := err.(*net.OpError); ok {
|
||||||
|
if se, ok := ne.Err.(*os.SyscallError); ok {
|
||||||
|
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
|
||||||
|
brokenPipe = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go
|
||||||
|
// and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them
|
||||||
|
if strErr, ok := err.(error); ok {
|
||||||
|
//stolen from http2/server.go in x/net
|
||||||
|
var (
|
||||||
|
errClientDisconnected = errors.New("client disconnected")
|
||||||
|
errClosedBody = errors.New("body closed by handler")
|
||||||
|
errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
|
||||||
|
errStreamClosed = errors.New("http2: stream closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors.Is(strErr, errClientDisconnected) ||
|
||||||
|
errors.Is(strErr, errClosedBody) ||
|
||||||
|
errors.Is(strErr, errHandlerComplete) ||
|
||||||
|
errors.Is(strErr, errStreamClosed) {
|
||||||
|
brokenPipe = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if brokenPipe {
|
||||||
|
log.Warningf("client %s disconnected : %s", c.ClientIP(), err)
|
||||||
|
c.Abort()
|
||||||
|
} else {
|
||||||
|
filename := trace.WriteStackTrace(err)
|
||||||
|
log.Warningf("client %s error : %s", c.ClientIP(), err)
|
||||||
|
log.Warningf("stacktrace written to %s, please join to your issue", filename)
|
||||||
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one.
|
||||||
func CustomRecoveryWithWriter() gin.HandlerFunc {
|
func CustomRecoveryWithWriter() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
defer func() {
|
defer recoverFromPanic(c)
|
||||||
if err := recover(); err != nil {
|
|
||||||
// Check for a broken connection, as it is not really a
|
|
||||||
// condition that warrants a panic stack trace.
|
|
||||||
var brokenPipe bool
|
|
||||||
if ne, ok := err.(*net.OpError); ok {
|
|
||||||
if se, ok := ne.Err.(*os.SyscallError); ok {
|
|
||||||
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
|
|
||||||
brokenPipe = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go
|
|
||||||
// and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them
|
|
||||||
if strErr, ok := err.(error); ok {
|
|
||||||
//stolen from http2/server.go in x/net
|
|
||||||
var (
|
|
||||||
errClientDisconnected = errors.New("client disconnected")
|
|
||||||
errClosedBody = errors.New("body closed by handler")
|
|
||||||
errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
|
|
||||||
errStreamClosed = errors.New("http2: stream closed")
|
|
||||||
)
|
|
||||||
if errors.Is(strErr, errClientDisconnected) ||
|
|
||||||
errors.Is(strErr, errClosedBody) ||
|
|
||||||
errors.Is(strErr, errHandlerComplete) ||
|
|
||||||
errors.Is(strErr, errStreamClosed) {
|
|
||||||
brokenPipe = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if brokenPipe {
|
|
||||||
log.Warningf("client %s disconnected : %s", c.ClientIP(), err)
|
|
||||||
c.Abort()
|
|
||||||
} else {
|
|
||||||
filename := trace.WriteStackTrace(err)
|
|
||||||
log.Warningf("client %s error : %s", c.ClientIP(), err)
|
|
||||||
log.Warningf("stacktrace written to %s, please join to your issue", filename)
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// XXX: could be a method of LocalApiServerCfg
|
||||||
|
func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, error) {
|
||||||
|
clog := log.New()
|
||||||
|
|
||||||
|
if err := types.ConfigureLogger(clog); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("while configuring gin logger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LogLevel != nil {
|
||||||
|
clog.SetLevel(*config.LogLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LogMedia != "file" {
|
||||||
|
return clog, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log rotation
|
||||||
|
|
||||||
|
logFile := filepath.Join(config.LogDir, "crowdsec_api.log")
|
||||||
|
log.Debugf("starting router, logging to %s", logFile)
|
||||||
|
|
||||||
|
logger := &lumberjack.Logger{
|
||||||
|
Filename: logFile,
|
||||||
|
MaxSize: 500, //megabytes
|
||||||
|
MaxBackups: 3,
|
||||||
|
MaxAge: 28, //days
|
||||||
|
Compress: true, //disabled by default
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LogMaxSize != 0 {
|
||||||
|
logger.MaxSize = config.LogMaxSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LogMaxFiles != 0 {
|
||||||
|
logger.MaxBackups = config.LogMaxFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LogMaxAge != 0 {
|
||||||
|
logger.MaxAge = config.LogMaxAge
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.CompressLogs != nil {
|
||||||
|
logger.Compress = *config.CompressLogs
|
||||||
|
}
|
||||||
|
|
||||||
|
clog.SetOutput(logger)
|
||||||
|
|
||||||
|
return clog, logFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a LAPI server.
|
||||||
|
// It sets up a gin router, a database client, and a controller.
|
||||||
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
var flushScheduler *gocron.Scheduler
|
var flushScheduler *gocron.Scheduler
|
||||||
|
|
||||||
dbClient, err := database.NewClient(config.DbConfig)
|
dbClient, err := database.NewClient(config.DbConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to init database client: %w", err)
|
return nil, fmt.Errorf("unable to init database client: %w", err)
|
||||||
|
@ -115,63 +169,26 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logFile := ""
|
|
||||||
if config.LogMedia == "file" {
|
|
||||||
logFile = filepath.Join(config.LogDir, "crowdsec_api.log")
|
|
||||||
}
|
|
||||||
|
|
||||||
if log.GetLevel() < log.DebugLevel {
|
if log.GetLevel() < log.DebugLevel {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
log.Debugf("starting router, logging to %s", logFile)
|
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
|
|
||||||
|
router.ForwardedByClientIP = false
|
||||||
|
|
||||||
if config.TrustedProxies != nil && config.UseForwardedForHeaders {
|
if config.TrustedProxies != nil && config.UseForwardedForHeaders {
|
||||||
if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil {
|
if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil {
|
||||||
return nil, fmt.Errorf("while setting trusted_proxies: %w", err)
|
return nil, fmt.Errorf("while setting trusted_proxies: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
router.ForwardedByClientIP = true
|
router.ForwardedByClientIP = true
|
||||||
} else {
|
|
||||||
router.ForwardedByClientIP = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*The logger that will be used by handlers*/
|
// The logger that will be used by handlers
|
||||||
clog := log.New()
|
clog, logFile, err := newGinLogger(config)
|
||||||
|
if err != nil {
|
||||||
if err := types.ConfigureLogger(clog); err != nil {
|
return nil, err
|
||||||
return nil, fmt.Errorf("while configuring gin logger: %w", err)
|
|
||||||
}
|
|
||||||
if config.LogLevel != nil {
|
|
||||||
clog.SetLevel(*config.LogLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*Configure logs*/
|
|
||||||
if logFile != "" {
|
|
||||||
_maxsize := 500
|
|
||||||
if config.LogMaxSize != 0 {
|
|
||||||
_maxsize = config.LogMaxSize
|
|
||||||
}
|
|
||||||
_maxfiles := 3
|
|
||||||
if config.LogMaxFiles != 0 {
|
|
||||||
_maxfiles = config.LogMaxFiles
|
|
||||||
}
|
|
||||||
_maxage := 28
|
|
||||||
if config.LogMaxAge != 0 {
|
|
||||||
_maxage = config.LogMaxAge
|
|
||||||
}
|
|
||||||
_compress := true
|
|
||||||
if config.CompressLogs != nil {
|
|
||||||
_compress = *config.CompressLogs
|
|
||||||
}
|
|
||||||
|
|
||||||
LogOutput := &lumberjack.Logger{
|
|
||||||
Filename: logFile,
|
|
||||||
MaxSize: _maxsize, //megabytes
|
|
||||||
MaxBackups: _maxfiles,
|
|
||||||
MaxAge: _maxage, //days
|
|
||||||
Compress: _compress, //disabled by default
|
|
||||||
}
|
|
||||||
clog.SetOutput(LogOutput)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel)
|
gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel)
|
||||||
|
@ -206,41 +223,50 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration,
|
DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration,
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiClient *apic
|
var (
|
||||||
var papiClient *Papi
|
apiClient *apic
|
||||||
var isMachineEnrolled = false
|
papiClient *Papi
|
||||||
|
isMachineEnrolled = false
|
||||||
|
)
|
||||||
|
|
||||||
|
controller.AlertsAddChan = nil
|
||||||
|
controller.DecisionDeleteChan = nil
|
||||||
|
|
||||||
if config.OnlineClient != nil && config.OnlineClient.Credentials != nil {
|
if config.OnlineClient != nil && config.OnlineClient.Credentials != nil {
|
||||||
log.Printf("Loading CAPI manager")
|
log.Printf("Loading CAPI manager")
|
||||||
|
|
||||||
apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists)
|
apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("CAPI manager configured successfully")
|
log.Infof("CAPI manager configured successfully")
|
||||||
isMachineEnrolled = isEnrolled(apiClient.apiClient)
|
|
||||||
controller.AlertsAddChan = apiClient.AlertsAddChan
|
controller.AlertsAddChan = apiClient.AlertsAddChan
|
||||||
if isMachineEnrolled {
|
|
||||||
|
if apiClient.apiClient.IsEnrolled() {
|
||||||
|
isMachineEnrolled = true
|
||||||
|
|
||||||
log.Infof("Machine is enrolled in the console, Loading PAPI Client")
|
log.Infof("Machine is enrolled in the console, Loading PAPI Client")
|
||||||
|
|
||||||
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
|
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
|
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("Machine is not enrolled in the console, can't synchronize with the console")
|
log.Errorf("Machine is not enrolled in the console, can't synchronize with the console")
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
apiClient = nil
|
|
||||||
controller.AlertsAddChan = nil
|
|
||||||
controller.DecisionDeleteChan = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if trustedIPs, err := config.GetTrustedIPs(); err == nil {
|
trustedIPs, err := config.GetTrustedIPs()
|
||||||
controller.TrustedIPs = trustedIPs
|
if err != nil {
|
||||||
} else {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
controller.TrustedIPs = trustedIPs
|
||||||
|
|
||||||
return &APIServer{
|
return &APIServer{
|
||||||
URL: config.ListenURI,
|
URL: config.ListenURI,
|
||||||
TLS: config.TLS,
|
TLS: config.TLS,
|
||||||
|
@ -255,80 +281,20 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
consoleConfig: config.ConsoleConfig,
|
consoleConfig: config.ConsoleConfig,
|
||||||
isEnrolled: isMachineEnrolled,
|
isEnrolled: isMachineEnrolled,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func isEnrolled(client *apiclient.ApiClient) bool {
|
|
||||||
apiHTTPClient := client.GetClient()
|
|
||||||
jwtTransport := apiHTTPClient.Transport.(*apiclient.JWTTransport)
|
|
||||||
tokenStr := jwtTransport.Token
|
|
||||||
|
|
||||||
token, _ := jwt.Parse(tokenStr, nil)
|
|
||||||
if token == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
|
||||||
_, ok := claims["organization_id"]
|
|
||||||
|
|
||||||
return ok
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIServer) Router() (*gin.Engine, error) {
|
func (s *APIServer) Router() (*gin.Engine, error) {
|
||||||
return s.router, nil
|
return s.router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIServer) GetTLSConfig() (*tls.Config, error) {
|
|
||||||
var caCert []byte
|
|
||||||
var err error
|
|
||||||
var caCertPool *x509.CertPool
|
|
||||||
var clientAuthType tls.ClientAuthType
|
|
||||||
|
|
||||||
if s.TLS == nil {
|
|
||||||
return &tls.Config{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.TLS.ClientVerification == "" {
|
|
||||||
//sounds like a sane default : verify client cert if given, but don't make it mandatory
|
|
||||||
clientAuthType = tls.VerifyClientCertIfGiven
|
|
||||||
} else {
|
|
||||||
clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.TLS.CACertPath != "" {
|
|
||||||
if clientAuthType > tls.RequestClientCert {
|
|
||||||
log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String())
|
|
||||||
caCert, err = os.ReadFile(s.TLS.CACertPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("while opening cert file: %w", err)
|
|
||||||
}
|
|
||||||
caCertPool, err = x509.SystemCertPool()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Error loading system CA certificates: %s", err)
|
|
||||||
}
|
|
||||||
if caCertPool == nil {
|
|
||||||
caCertPool = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tls.Config{
|
|
||||||
ServerName: s.TLS.ServerName, //should it be removed ?
|
|
||||||
ClientAuth: clientAuthType,
|
|
||||||
ClientCAs: caCertPool,
|
|
||||||
MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *APIServer) Run(apiReady chan bool) error {
|
func (s *APIServer) Run(apiReady chan bool) error {
|
||||||
defer trace.CatchPanic("lapi/runServer")
|
defer trace.CatchPanic("lapi/runServer")
|
||||||
tlsCfg, err := s.GetTLSConfig()
|
|
||||||
|
tlsCfg, err := s.TLS.GetTLSConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("while creating TLS config: %w", err)
|
return fmt.Errorf("while creating TLS config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.httpServer = &http.Server{
|
s.httpServer = &http.Server{
|
||||||
Addr: s.URL,
|
Addr: s.URL,
|
||||||
Handler: s.router,
|
Handler: s.router,
|
||||||
|
@ -386,41 +352,74 @@ func (s *APIServer) Run(apiReady chan bool) error {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
s.httpServerTomb.Go(func() error {
|
s.httpServerTomb.Go(func() error { s.listenAndServeURL(apiReady); return nil })
|
||||||
go func() {
|
|
||||||
apiReady <- true
|
|
||||||
log.Infof("CrowdSec Local API listening on %s", s.URL)
|
|
||||||
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
|
|
||||||
if s.TLS.KeyFilePath == "" {
|
|
||||||
log.Fatalf("while serving local API: %v", errors.New("missing TLS key file"))
|
|
||||||
} else if s.TLS.CertFilePath == "" {
|
|
||||||
log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil {
|
|
||||||
log.Fatalf("while serving local API: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
|
||||||
log.Fatalf("while serving local API: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
<-s.httpServerTomb.Dying()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// listenAndServeURL starts the http server and blocks until it's closed
|
||||||
|
// it also updates the URL field with the actual address the server is listening on
|
||||||
|
// it's meant to be run in a separate goroutine
|
||||||
|
func (s *APIServer) listenAndServeURL(apiReady chan bool) {
|
||||||
|
serverError := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
listener, err := net.Listen("tcp", s.URL)
|
||||||
|
if err != nil {
|
||||||
|
serverError <- fmt.Errorf("listening on %s: %w", s.URL, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.URL = listener.Addr().String()
|
||||||
|
log.Infof("CrowdSec Local API listening on %s", s.URL)
|
||||||
|
apiReady <- true
|
||||||
|
|
||||||
|
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
|
||||||
|
if s.TLS.KeyFilePath == "" {
|
||||||
|
serverError <- errors.New("missing TLS key file")
|
||||||
|
return
|
||||||
|
} else if s.TLS.CertFilePath == "" {
|
||||||
|
serverError <- errors.New("missing TLS cert file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.httpServer.ServeTLS(listener, s.TLS.CertFilePath, s.TLS.KeyFilePath)
|
||||||
|
} else {
|
||||||
|
err = s.httpServer.Serve(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
serverError <- fmt.Errorf("while serving local API: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-serverError:
|
||||||
|
log.Fatalf("while starting API server: %s", err)
|
||||||
|
case <-s.httpServerTomb.Dying():
|
||||||
|
log.Infof("Shutting down API server")
|
||||||
|
// do we need a graceful shutdown here?
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||||
|
log.Errorf("while shutting down http server: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *APIServer) Close() {
|
func (s *APIServer) Close() {
|
||||||
if s.apic != nil {
|
if s.apic != nil {
|
||||||
s.apic.Shutdown() // stop apic first since it use dbClient
|
s.apic.Shutdown() // stop apic first since it use dbClient
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.papi != nil {
|
if s.papi != nil {
|
||||||
s.papi.Shutdown() // papi also uses the dbClient
|
s.papi.Shutdown() // papi also uses the dbClient
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dbClient.Ent.Close()
|
s.dbClient.Ent.Close()
|
||||||
|
|
||||||
if s.flushScheduler != nil {
|
if s.flushScheduler != nil {
|
||||||
s.flushScheduler.Stop()
|
s.flushScheduler.Stop()
|
||||||
}
|
}
|
||||||
|
@ -428,6 +427,7 @@ func (s *APIServer) Close() {
|
||||||
|
|
||||||
func (s *APIServer) Shutdown() error {
|
func (s *APIServer) Shutdown() error {
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
||||||
if s.httpServer != nil {
|
if s.httpServer != nil {
|
||||||
if err := s.httpServer.Shutdown(context.TODO()); err != nil {
|
if err := s.httpServer.Shutdown(context.TODO()); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -438,13 +438,17 @@ func (s *APIServer) Shutdown() error {
|
||||||
if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok {
|
if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok {
|
||||||
pipe.Close()
|
pipe.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok {
|
if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok {
|
||||||
pipe.Close()
|
pipe.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
s.httpServerTomb.Kill(nil)
|
s.httpServerTomb.Kill(nil)
|
||||||
|
|
||||||
if err := s.httpServerTomb.Wait(); err != nil {
|
if err := s.httpServerTomb.Wait(); err != nil {
|
||||||
return fmt.Errorf("while waiting on httpServerTomb: %w", err)
|
return fmt.Errorf("while waiting on httpServerTomb: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -453,36 +457,41 @@ func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIServer) InitController() error {
|
func (s *APIServer) InitController() error {
|
||||||
|
|
||||||
err := s.controller.Init()
|
err := s.controller.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("controller init: %w", err)
|
return fmt.Errorf("controller init: %w", err)
|
||||||
}
|
}
|
||||||
if s.TLS != nil {
|
|
||||||
var cacheExpiration time.Duration
|
if s.TLS == nil {
|
||||||
if s.TLS.CacheExpiration != nil {
|
return nil
|
||||||
cacheExpiration = *s.TLS.CacheExpiration
|
|
||||||
} else {
|
|
||||||
cacheExpiration = time.Hour
|
|
||||||
}
|
|
||||||
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
|
|
||||||
cacheExpiration,
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"component": "tls-auth",
|
|
||||||
"type": "agent",
|
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("while creating TLS auth for agents: %w", err)
|
|
||||||
}
|
|
||||||
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
|
|
||||||
cacheExpiration,
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"component": "tls-auth",
|
|
||||||
"type": "bouncer",
|
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("while creating TLS auth for bouncers: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
// TLS is configured: create the TLSAuth middleware for agents and bouncers
|
||||||
|
|
||||||
|
cacheExpiration := time.Hour
|
||||||
|
if s.TLS.CacheExpiration != nil {
|
||||||
|
cacheExpiration = *s.TLS.CacheExpiration
|
||||||
|
}
|
||||||
|
|
||||||
|
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
|
||||||
|
cacheExpiration,
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"component": "tls-auth",
|
||||||
|
"type": "agent",
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("while creating TLS auth for agents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
|
||||||
|
cacheExpiration,
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"component": "tls-auth",
|
||||||
|
"type": "bouncer",
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("while creating TLS auth for bouncers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,21 +11,20 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-openapi/strfmt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/crowdsecurity/go-cs-lib/cstest"
|
"github.com/crowdsecurity/go-cs-lib/cstest"
|
||||||
"github.com/crowdsecurity/go-cs-lib/version"
|
"github.com/crowdsecurity/go-cs-lib/version"
|
||||||
|
|
||||||
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
|
||||||
"github.com/go-openapi/strfmt"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var testMachineID = "test"
|
var testMachineID = "test"
|
||||||
|
@ -46,6 +45,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
||||||
|
|
||||||
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
||||||
|
|
||||||
dbconfig := csconfig.DatabaseCfg{
|
dbconfig := csconfig.DatabaseCfg{
|
||||||
|
@ -70,6 +70,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
|
||||||
if err := config.API.Server.LoadProfiles(); err != nil {
|
if err := config.API.Server.LoadProfiles(); err != nil {
|
||||||
log.Fatalf("failed to load profiles: %s", err)
|
log.Fatalf("failed to load profiles: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,6 +82,7 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
||||||
|
|
||||||
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
||||||
|
|
||||||
dbconfig := csconfig.DatabaseCfg{
|
dbconfig := csconfig.DatabaseCfg{
|
||||||
|
@ -107,18 +109,22 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
|
||||||
if err := config.API.Server.LoadProfiles(); err != nil {
|
if err := config.API.Server.LoadProfiles(); err != nil {
|
||||||
log.Fatalf("failed to load profiles: %s", err)
|
log.Fatalf("failed to load profiles: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
|
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
|
||||||
config := LoadTestConfig(t)
|
config := LoadTestConfig(t)
|
||||||
|
|
||||||
os.Remove("./ent")
|
os.Remove("./ent")
|
||||||
apiServer, err := NewServer(config.API.Server)
|
apiServer, err := NewServer(config.API.Server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Creating new API server")
|
log.Printf("Creating new API server")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
return apiServer, config, nil
|
return apiServer, config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +141,7 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return router, config, nil
|
return router, config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,12 +157,14 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Creating new API server")
|
log.Printf("Creating new API server")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
router, err := apiServer.Router()
|
router, err := apiServer.Router()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
return nil, config, fmt.Errorf("unable to run local API: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return router, config, nil
|
return router, config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,9 +173,11 @@ func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create new database client: %s", err)
|
return fmt.Errorf("unable to create new database client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dbClient.ValidateMachine(machineID); err != nil {
|
if err := dbClient.ValidateMachine(machineID); err != nil {
|
||||||
return fmt.Errorf("unable to validate machine: %s", err)
|
return fmt.Errorf("unable to validate machine: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,23 +190,24 @@ func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("Unable to list machines: %s", err)
|
return "", fmt.Errorf("Unable to list machines: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, machine := range machines {
|
for _, machine := range machines {
|
||||||
if machine.MachineId == machineID {
|
if machine.MachineId == machineID {
|
||||||
return machine.IpAddress, nil
|
return machine.IpAddress, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAlertReaderFromFile(path string) *strings.Reader {
|
func GetAlertReaderFromFile(path string) *strings.Reader {
|
||||||
|
|
||||||
alertContentBytes, err := os.ReadFile(path)
|
alertContentBytes, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
alerts := make([]*models.Alert, 0)
|
alerts := make([]*models.Alert, 0)
|
||||||
if err := json.Unmarshal(alertContentBytes, &alerts); err != nil {
|
if err = json.Unmarshal(alertContentBytes, &alerts); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,12 +220,13 @@ func GetAlertReaderFromFile(path string) *strings.Reader {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
return strings.NewReader(string(alertContent))
|
|
||||||
|
|
||||||
|
return strings.NewReader(string(alertContent))
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
|
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
|
||||||
var response []*models.Decision
|
var response []*models.Decision
|
||||||
|
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, 0, errors.New("response is nil")
|
return nil, 0, errors.New("response is nil")
|
||||||
}
|
}
|
||||||
|
@ -221,11 +234,13 @@ func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp.Code, err
|
return nil, resp.Code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, resp.Code, nil
|
return response, resp.Code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
|
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
|
||||||
var response map[string]string
|
var response map[string]string
|
||||||
|
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, 0, errors.New("response is nil")
|
return nil, 0, errors.New("response is nil")
|
||||||
}
|
}
|
||||||
|
@ -233,11 +248,13 @@ func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp.Code, err
|
return nil, resp.Code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, resp.Code, nil
|
return response, resp.Code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
|
func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
|
||||||
var response models.DeleteDecisionResponse
|
var response models.DeleteDecisionResponse
|
||||||
|
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, 0, errors.New("response is nil")
|
return nil, 0, errors.New("response is nil")
|
||||||
}
|
}
|
||||||
|
@ -245,11 +262,13 @@ func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDec
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp.Code, err
|
return nil, resp.Code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, resp.Code, nil
|
return &response, resp.Code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
|
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
|
||||||
response := make(map[string][]*models.Decision)
|
response := make(map[string][]*models.Decision)
|
||||||
|
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, 0, errors.New("response is nil")
|
return nil, 0, errors.New("response is nil")
|
||||||
}
|
}
|
||||||
|
@ -257,6 +276,7 @@ func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*mod
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp.Code, err
|
return nil, resp.Code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, resp.Code, nil
|
return response, resp.Code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,6 +291,7 @@ func CreateTestMachine(router *gin.Engine) (string, error) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
|
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -279,10 +300,12 @@ func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("unable to create new database client: %s", err)
|
log.Fatalf("unable to create new database client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey, err := middlewares.GenerateAPIKey(keyLength)
|
apiKey, err := middlewares.GenerateAPIKey(keyLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to generate api key: %s", err)
|
return "", fmt.Errorf("unable to generate api key: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
|
_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to create blocker: %s", err)
|
return "", fmt.Errorf("unable to create blocker: %s", err)
|
||||||
|
@ -322,7 +345,6 @@ func TestUnknownPath(t *testing.T) {
|
||||||
router.ServeHTTP(w, req)
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, 404, w.Code)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -348,6 +370,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
||||||
|
|
||||||
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
||||||
|
|
||||||
dbconfig := csconfig.DatabaseCfg{
|
dbconfig := csconfig.DatabaseCfg{
|
||||||
|
@ -370,10 +393,12 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
||||||
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
|
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
api, err := NewServer(&cfg)
|
api, err := NewServer(&cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create api : %s", err)
|
t.Fatalf("failed to create api : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if api == nil {
|
if api == nil {
|
||||||
t.Fatalf("failed to create api #2 is nbill")
|
t.Fatalf("failed to create api #2 is nbill")
|
||||||
}
|
}
|
||||||
|
@ -397,11 +422,9 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
|
||||||
t.Fatalf("expected %s in %s", expectedStr, string(data))
|
t.Fatalf("expected %s in %s", expectedStr, string(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoggingErrorToFileConfig(t *testing.T) {
|
func TestLoggingErrorToFileConfig(t *testing.T) {
|
||||||
|
|
||||||
/*declare settings*/
|
/*declare settings*/
|
||||||
maxAge := "1h"
|
maxAge := "1h"
|
||||||
flushConfig := csconfig.FlushDBCfg{
|
flushConfig := csconfig.FlushDBCfg{
|
||||||
|
@ -409,6 +432,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
|
||||||
|
|
||||||
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
t.Cleanup(func() { os.RemoveAll(tempDir) })
|
||||||
|
|
||||||
dbconfig := csconfig.DatabaseCfg{
|
dbconfig := csconfig.DatabaseCfg{
|
||||||
|
@ -434,6 +458,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create api : %s", err)
|
t.Fatalf("failed to create api : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if api == nil {
|
if api == nil {
|
||||||
t.Fatalf("failed to create api #2 is nbill")
|
t.Fatalf("failed to create api #2 is nbill")
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,13 +6,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/alexliesenfeld/health"
|
"github.com/alexliesenfeld/health"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1"
|
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Controller struct {
|
type Controller struct {
|
||||||
|
|
|
@ -10,15 +10,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
jwt "github.com/appleboy/gin-jwt/v2"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-openapi/strfmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-openapi/strfmt"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func FormatOneAlert(alert *ent.Alert) *models.Alert {
|
func FormatOneAlert(alert *ent.Alert) *models.Alert {
|
||||||
|
|
|
@ -7,11 +7,12 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/fflag"
|
"github.com/crowdsecurity/crowdsec/pkg/fflag"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Format decisions for the bouncers
|
// Format decisions for the bouncers
|
||||||
|
|
|
@ -3,9 +3,10 @@ package v1
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) {
|
func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) {
|
||||||
|
|
|
@ -3,10 +3,11 @@ package v1
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-openapi/strfmt"
|
"github.com/go-openapi/strfmt"
|
||||||
|
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Controller) CreateMachine(gctx *gin.Context) {
|
func (c *Controller) CreateMachine(gctx *gin.Context) {
|
||||||
|
|
|
@ -35,8 +35,11 @@ var LapiBouncerHits = prometheus.NewCounterVec(
|
||||||
[]string{"bouncer", "route", "method"},
|
[]string{"bouncer", "route", "method"},
|
||||||
)
|
)
|
||||||
|
|
||||||
/* keep track of the number of calls (per bouncer) that lead to nil/non-nil responses.
|
/*
|
||||||
while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi*/
|
keep track of the number of calls (per bouncer) that lead to nil/non-nil responses.
|
||||||
|
|
||||||
|
while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi
|
||||||
|
*/
|
||||||
var LapiNilDecisions = prometheus.NewCounterVec(
|
var LapiNilDecisions = prometheus.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "cs_lapi_decisions_ko_total",
|
Name: "cs_lapi_decisions_ko_total",
|
||||||
|
|
|
@ -4,8 +4,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -91,9 +92,9 @@ func TestGetDecisionFilters(t *testing.T) {
|
||||||
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
|
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err := readDecisionsGetResp(w)
|
decisions, code, err := readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 2, len(decisions))
|
assert.Len(t, decisions, 2)
|
||||||
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
||||||
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
||||||
assert.Equal(t, int64(1), decisions[0].ID)
|
assert.Equal(t, int64(1), decisions[0].ID)
|
||||||
|
@ -106,9 +107,9 @@ func TestGetDecisionFilters(t *testing.T) {
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err = readDecisionsGetResp(w)
|
decisions, code, err = readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 2, len(decisions))
|
assert.Len(t, decisions, 2)
|
||||||
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
||||||
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
||||||
assert.Equal(t, int64(1), decisions[0].ID)
|
assert.Equal(t, int64(1), decisions[0].ID)
|
||||||
|
@ -124,9 +125,9 @@ func TestGetDecisionFilters(t *testing.T) {
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err = readDecisionsGetResp(w)
|
decisions, code, err = readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 1, len(decisions))
|
assert.Len(t, decisions, 1)
|
||||||
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
||||||
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
||||||
assert.Equal(t, int64(1), decisions[0].ID)
|
assert.Equal(t, int64(1), decisions[0].ID)
|
||||||
|
@ -139,9 +140,9 @@ func TestGetDecisionFilters(t *testing.T) {
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err = readDecisionsGetResp(w)
|
decisions, code, err = readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 1, len(decisions))
|
assert.Len(t, decisions, 1)
|
||||||
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
|
||||||
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
assert.Equal(t, "91.121.79.179", *decisions[0].Value)
|
||||||
assert.Equal(t, int64(1), decisions[0].ID)
|
assert.Equal(t, int64(1), decisions[0].ID)
|
||||||
|
@ -153,12 +154,11 @@ func TestGetDecisionFilters(t *testing.T) {
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err = readDecisionsGetResp(w)
|
decisions, code, err = readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 2, len(decisions))
|
assert.Len(t, decisions, 2)
|
||||||
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179")
|
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179")
|
||||||
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178")
|
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetDecision(t *testing.T) {
|
func TestGetDecision(t *testing.T) {
|
||||||
|
@ -171,9 +171,9 @@ func TestGetDecision(t *testing.T) {
|
||||||
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
|
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
decisions, code, err := readDecisionsGetResp(w)
|
decisions, code, err := readDecisionsGetResp(w)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 3, len(decisions))
|
assert.Len(t, decisions, 3)
|
||||||
/*decisions get doesn't perform deduplication*/
|
/*decisions get doesn't perform deduplication*/
|
||||||
assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario)
|
assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario)
|
||||||
assert.Equal(t, "127.0.0.1", *decisions[0].Value)
|
assert.Equal(t, "127.0.0.1", *decisions[0].Value)
|
||||||
|
@ -190,7 +190,7 @@ func TestGetDecision(t *testing.T) {
|
||||||
// Get Decision with invalid filter. It should ignore this filter
|
// Get Decision with invalid filter. It should ignore this filter
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Equal(t, 3, len(decisions))
|
assert.Len(t, decisions, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteDecisionByID(t *testing.T) {
|
func TestDeleteDecisionByID(t *testing.T) {
|
||||||
|
@ -202,47 +202,47 @@ func TestDeleteDecisionByID(t *testing.T) {
|
||||||
//Have one alerts
|
//Have one alerts
|
||||||
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err := readDecisionsStreamResp(w)
|
decisions, code, err := readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, err, nil)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
|
|
||||||
// Delete alert with Invalid ID
|
// Delete alert with Invalid ID
|
||||||
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
|
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
|
||||||
assert.Equal(t, 400, w.Code)
|
assert.Equal(t, 400, w.Code)
|
||||||
err_resp, _, err := readDecisionsErrorResp(w)
|
errResp, _, err := readDecisionsErrorResp(w)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "decision_id must be valid integer", err_resp["message"])
|
assert.Equal(t, "decision_id must be valid integer", errResp["message"])
|
||||||
|
|
||||||
// Delete alert with ID that not exist
|
// Delete alert with ID that not exist
|
||||||
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
|
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
err_resp, _, err = readDecisionsErrorResp(w)
|
errResp, _, err = readDecisionsErrorResp(w)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", err_resp["message"])
|
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"])
|
||||||
|
|
||||||
//Have one alerts
|
//Have one alerts
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err = readDecisionsStreamResp(w)
|
decisions, code, err = readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, err, nil)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
|
|
||||||
// Delete alert with valid ID
|
// Delete alert with valid ID
|
||||||
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
|
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
resp, _, err := readDecisionsDeleteResp(w)
|
resp, _, err := readDecisionsDeleteResp(w)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, resp.NbDeleted, "1")
|
assert.Equal(t, "1", resp.NbDeleted)
|
||||||
|
|
||||||
//Have one alert (because we delete an alert that has dup targets)
|
//Have one alert (because we delete an alert that has dup targets)
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err = readDecisionsStreamResp(w)
|
decisions, code, err = readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, err, nil)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteDecision(t *testing.T) {
|
func TestDeleteDecision(t *testing.T) {
|
||||||
|
@ -254,16 +254,16 @@ func TestDeleteDecision(t *testing.T) {
|
||||||
// Delete alert with Invalid filter
|
// Delete alert with Invalid filter
|
||||||
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
|
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
|
||||||
assert.Equal(t, 500, w.Code)
|
assert.Equal(t, 500, w.Code)
|
||||||
err_resp, _, err := readDecisionsErrorResp(w)
|
errResp, _, err := readDecisionsErrorResp(w)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter")
|
assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"])
|
||||||
|
|
||||||
// Delete all alert
|
// Delete all alert
|
||||||
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD)
|
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
resp, _, err := readDecisionsDeleteResp(w)
|
resp, _, err := readDecisionsDeleteResp(w)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, resp.NbDeleted, "3")
|
assert.Equal(t, "3", resp.NbDeleted)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamStartDecisionDedup(t *testing.T) {
|
func TestStreamStartDecisionDedup(t *testing.T) {
|
||||||
|
@ -276,10 +276,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
|
||||||
// Get Stream, we only get one decision (the longest one)
|
// Get Stream, we only get one decision (the longest one)
|
||||||
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err := readDecisionsStreamResp(w)
|
decisions, code, err := readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, nil, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
assert.Equal(t, int64(3), decisions["new"][0].ID)
|
assert.Equal(t, int64(3), decisions["new"][0].ID)
|
||||||
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
||||||
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
||||||
|
@ -291,10 +291,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
|
||||||
// Get Stream, we only get one decision (the longest one, id=2)
|
// Get Stream, we only get one decision (the longest one, id=2)
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err = readDecisionsStreamResp(w)
|
decisions, code, err = readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, nil, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
assert.Equal(t, int64(2), decisions["new"][0].ID)
|
assert.Equal(t, int64(2), decisions["new"][0].ID)
|
||||||
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
||||||
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
||||||
|
@ -306,10 +306,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
|
||||||
// And get the remaining decision (1)
|
// And get the remaining decision (1)
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err = readDecisionsStreamResp(w)
|
decisions, code, err = readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, nil, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 0, len(decisions["deleted"]))
|
assert.Empty(t, decisions["deleted"])
|
||||||
assert.Equal(t, 1, len(decisions["new"]))
|
assert.Len(t, decisions["new"], 1)
|
||||||
assert.Equal(t, int64(1), decisions["new"][0].ID)
|
assert.Equal(t, int64(1), decisions["new"][0].ID)
|
||||||
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
||||||
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
||||||
|
@ -321,13 +321,13 @@ func TestStreamStartDecisionDedup(t *testing.T) {
|
||||||
//and now we only get a deleted decision
|
//and now we only get a deleted decision
|
||||||
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
|
||||||
decisions, code, err = readDecisionsStreamResp(w)
|
decisions, code, err = readDecisionsStreamResp(w)
|
||||||
assert.Equal(t, nil, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, 1, len(decisions["deleted"]))
|
assert.Len(t, decisions["deleted"], 1)
|
||||||
assert.Equal(t, int64(1), decisions["deleted"][0].ID)
|
assert.Equal(t, int64(1), decisions["deleted"][0].ID)
|
||||||
assert.Equal(t, "test", *decisions["deleted"][0].Origin)
|
assert.Equal(t, "test", *decisions["deleted"][0].Origin)
|
||||||
assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value)
|
assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value)
|
||||||
assert.Equal(t, 0, len(decisions["new"]))
|
assert.Empty(t, decisions["new"])
|
||||||
}
|
}
|
||||||
|
|
||||||
type DecisionCheck struct {
|
type DecisionCheck struct {
|
||||||
|
|
|
@ -91,5 +91,4 @@ func TestLogin(t *testing.T) {
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "\"token\"")
|
assert.Contains(t, w.Body.String(), "\"token\"")
|
||||||
assert.Contains(t, w.Body.String(), "\"expire\"")
|
assert.Contains(t, w.Body.String(), "\"expire\"")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,6 @@ func TestCreateMachine(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, 201, w.Code)
|
assert.Equal(t, 201, w.Code)
|
||||||
assert.Equal(t, "", w.Body.String())
|
assert.Equal(t, "", w.Body.String())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateMachineWithForwardedFor(t *testing.T) {
|
func TestCreateMachineWithForwardedFor(t *testing.T) {
|
||||||
|
@ -78,6 +77,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Could not get machine IP : %s", err)
|
log.Fatalf("Could not get machine IP : %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, "1.1.1.1", ip)
|
assert.Equal(t, "1.1.1.1", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,5 +165,4 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, 403, w.Code)
|
assert.Equal(t, 403, w.Code)
|
||||||
assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String())
|
assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,18 +8,19 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
APIKeyHeader = "X-Api-Key"
|
APIKeyHeader = "X-Api-Key"
|
||||||
bouncerContextKey = "bouncer_info"
|
bouncerContextKey = "bouncer_info"
|
||||||
// max allowed by bcrypt 72 = 54 bytes in base64
|
// max allowed by bcrypt 72 = 54 bytes in base64
|
||||||
dummyAPIKeySize = 54
|
dummyAPIKeySize = 54
|
||||||
)
|
)
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
|
|
|
@ -10,15 +10,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
jwt "github.com/appleboy/gin-jwt/v2"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/go-openapi/strfmt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database"
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/models"
|
"github.com/crowdsecurity/crowdsec/pkg/models"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-openapi/strfmt"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var identityKey = "id"
|
var identityKey = "id"
|
||||||
|
@ -46,16 +47,12 @@ func IdentityHandler(c *gin.Context) interface{} {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
type authInput struct {
|
type authInput struct {
|
||||||
machineID string
|
machineID string
|
||||||
clientMachine *ent.Machine
|
clientMachine *ent.Machine
|
||||||
scenariosInput []string
|
scenariosInput []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
||||||
ret := authInput{}
|
ret := authInput{}
|
||||||
|
|
||||||
|
@ -123,8 +120,6 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
|
||||||
return &ret, nil
|
return &ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
|
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
|
||||||
var loginInput models.WatcherAuthRequest
|
var loginInput models.WatcherAuthRequest
|
||||||
var err error
|
var err error
|
||||||
|
@ -169,7 +164,6 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
|
||||||
return &ret, nil
|
return &ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
|
||||||
var err error
|
var err error
|
||||||
var auth *authInput
|
var auth *authInput
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
package apiserver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getTLSAuthType(authType string) (tls.ClientAuthType, error) {
|
|
||||||
switch authType {
|
|
||||||
case "NoClientCert":
|
|
||||||
return tls.NoClientCert, nil
|
|
||||||
case "RequestClientCert":
|
|
||||||
log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
|
|
||||||
return tls.RequestClientCert, nil
|
|
||||||
case "RequireAnyClientCert":
|
|
||||||
log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
|
|
||||||
return tls.RequireAnyClientCert, nil
|
|
||||||
case "VerifyClientCertIfGiven":
|
|
||||||
return tls.VerifyClientCertIfGiven, nil
|
|
||||||
case "RequireAndVerifyClientCert":
|
|
||||||
return tls.RequireAndVerifyClientCert, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -212,18 +212,6 @@ type LocalApiServerCfg struct {
|
||||||
CapiWhitelists *CapiWhitelist `yaml:"-"`
|
CapiWhitelists *CapiWhitelist `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSCfg struct {
|
|
||||||
CertFilePath string `yaml:"cert_file"`
|
|
||||||
KeyFilePath string `yaml:"key_file"`
|
|
||||||
ClientVerification string `yaml:"client_verification,omitempty"`
|
|
||||||
ServerName string `yaml:"server_name"`
|
|
||||||
CACertPath string `yaml:"ca_cert_path"`
|
|
||||||
AllowedAgentsOU []string `yaml:"agents_allowed_ou"`
|
|
||||||
AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"`
|
|
||||||
CRLPath string `yaml:"crl_path"`
|
|
||||||
CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) LoadAPIServer() error {
|
func (c *Config) LoadAPIServer() error {
|
||||||
if c.DisableAPI {
|
if c.DisableAPI {
|
||||||
log.Warning("crowdsec local API is disabled from flag")
|
log.Warning("crowdsec local API is disabled from flag")
|
||||||
|
@ -243,13 +231,16 @@ func (c *Config) LoadAPIServer() error {
|
||||||
if !*c.API.Server.Enable {
|
if !*c.API.Server.Enable {
|
||||||
log.Warning("crowdsec local API is disabled because 'enable' is set to false")
|
log.Warning("crowdsec local API is disabled because 'enable' is set to false")
|
||||||
c.DisableAPI = true
|
c.DisableAPI = true
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.DisableAPI {
|
if c.DisableAPI {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.API.Server.ListenURI == "" {
|
||||||
|
return fmt.Errorf("no listen_uri specified")
|
||||||
|
}
|
||||||
|
|
||||||
//inherit log level from common, then api->server
|
//inherit log level from common, then api->server
|
||||||
var logLevel log.Level
|
var logLevel log.Level
|
||||||
if c.API.Server.LogLevel != nil {
|
if c.API.Server.LogLevel != nil {
|
||||||
|
|
|
@ -219,7 +219,9 @@ func TestLoadAPIServer(t *testing.T) {
|
||||||
input: &Config{
|
input: &Config{
|
||||||
Self: []byte(configData),
|
Self: []byte(configData),
|
||||||
API: &APICfg{
|
API: &APICfg{
|
||||||
Server: &LocalApiServerCfg{},
|
Server: &LocalApiServerCfg{
|
||||||
|
ListenURI: "http://crowdsec.api",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Common: &CommonCfg{
|
Common: &CommonCfg{
|
||||||
LogDir: "./testdata/",
|
LogDir: "./testdata/",
|
||||||
|
|
87
pkg/csconfig/tls.go
Normal file
87
pkg/csconfig/tls.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package csconfig
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TLSCfg struct {
|
||||||
|
CertFilePath string `yaml:"cert_file"`
|
||||||
|
KeyFilePath string `yaml:"key_file"`
|
||||||
|
ClientVerification string `yaml:"client_verification,omitempty"`
|
||||||
|
ServerName string `yaml:"server_name"`
|
||||||
|
CACertPath string `yaml:"ca_cert_path"`
|
||||||
|
AllowedAgentsOU []string `yaml:"agents_allowed_ou"`
|
||||||
|
AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"`
|
||||||
|
CRLPath string `yaml:"crl_path"`
|
||||||
|
CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSCfg) GetAuthType() (tls.ClientAuthType, error) {
|
||||||
|
if t.ClientVerification == "" {
|
||||||
|
// sounds like a sane default: verify client cert if given, but don't make it mandatory
|
||||||
|
return tls.VerifyClientCertIfGiven, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t.ClientVerification {
|
||||||
|
case "NoClientCert":
|
||||||
|
return tls.NoClientCert, nil
|
||||||
|
case "RequestClientCert":
|
||||||
|
log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
|
||||||
|
return tls.RequestClientCert, nil
|
||||||
|
case "RequireAnyClientCert":
|
||||||
|
log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
|
||||||
|
return tls.RequireAnyClientCert, nil
|
||||||
|
case "VerifyClientCertIfGiven":
|
||||||
|
return tls.VerifyClientCertIfGiven, nil
|
||||||
|
case "RequireAndVerifyClientCert":
|
||||||
|
return tls.RequireAndVerifyClientCert, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unknown TLS client_verification value: %s", t.ClientVerification)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSCfg) GetTLSConfig() (*tls.Config, error) {
|
||||||
|
if t == nil {
|
||||||
|
return &tls.Config{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAuthType, err := t.GetAuthType()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Error loading system CA certificates: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if caCertPool == nil {
|
||||||
|
caCertPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
// the > condition below is a weird way to say "if a client certificate is required"
|
||||||
|
// see https://pkg.go.dev/crypto/tls#ClientAuthType
|
||||||
|
if clientAuthType > tls.RequestClientCert && t.CACertPath != "" {
|
||||||
|
log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String())
|
||||||
|
|
||||||
|
caCert, err := os.ReadFile(t.CACertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("while opening cert file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
ServerName: t.ServerName, //should it be removed ?
|
||||||
|
ClientAuth: clientAuthType,
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details
|
||||||
|
}, nil
|
||||||
|
}
|
51
test/bats/01_crowdsec_lapi.bats
Normal file
51
test/bats/01_crowdsec_lapi.bats
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
#!/usr/bin/env bats
|
||||||
|
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
|
||||||
|
|
||||||
|
set -u
|
||||||
|
|
||||||
|
setup_file() {
|
||||||
|
load "../lib/setup_file.sh"
|
||||||
|
}
|
||||||
|
|
||||||
|
teardown_file() {
|
||||||
|
load "../lib/teardown_file.sh"
|
||||||
|
}
|
||||||
|
|
||||||
|
setup() {
|
||||||
|
load "../lib/setup.sh"
|
||||||
|
load "../lib/bats-file/load.bash"
|
||||||
|
./instance-data load
|
||||||
|
}
|
||||||
|
|
||||||
|
teardown() {
|
||||||
|
./instance-crowdsec stop
|
||||||
|
}
|
||||||
|
|
||||||
|
#----------
|
||||||
|
|
||||||
|
# Tests for LAPI configuration and startup
|
||||||
|
|
||||||
|
@test "lapi (.api.server.enable=false)" {
|
||||||
|
rune -0 config_set '.api.server.enable=false'
|
||||||
|
rune -1 "${CROWDSEC}" -no-cs
|
||||||
|
assert_stderr --partial "You must run at least the API Server or crowdsec"
|
||||||
|
}
|
||||||
|
|
||||||
|
@test "lapi (no .api.server.listen_uri)" {
|
||||||
|
rune -0 config_set 'del(.api.server.listen_uri)'
|
||||||
|
rune -1 "${CROWDSEC}" -no-cs
|
||||||
|
assert_stderr --partial "no listen_uri specified"
|
||||||
|
}
|
||||||
|
|
||||||
|
@test "lapi (bad .api.server.listen_uri)" {
|
||||||
|
rune -0 config_set '.api.server.listen_uri="127.0.0.1:-80"'
|
||||||
|
rune -1 "${CROWDSEC}" -no-cs
|
||||||
|
assert_stderr --partial "while starting API server: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port"
|
||||||
|
}
|
||||||
|
|
||||||
|
@test "lapi (listen on random port)" {
|
||||||
|
config_set '.common.log_media="stdout"'
|
||||||
|
rune -0 config_set '.api.server.listen_uri="127.0.0.1:0"'
|
||||||
|
rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ setup() {
|
||||||
load "../lib/setup.sh"
|
load "../lib/setup.sh"
|
||||||
load "../lib/bats-file/load.bash"
|
load "../lib/bats-file/load.bash"
|
||||||
./instance-data load
|
./instance-data load
|
||||||
./instance-crowdsec start
|
# don't run crowdsec here, not all tests require a running instance
|
||||||
}
|
}
|
||||||
|
|
||||||
teardown() {
|
teardown() {
|
||||||
|
@ -204,6 +204,7 @@ teardown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@test "cscli lapi status" {
|
@test "cscli lapi status" {
|
||||||
|
rune -0 ./instance-crowdsec start
|
||||||
rune -0 cscli lapi status
|
rune -0 cscli lapi status
|
||||||
|
|
||||||
assert_stderr --partial "Loaded credentials from"
|
assert_stderr --partial "Loaded credentials from"
|
||||||
|
@ -260,6 +261,7 @@ teardown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@test "cscli - bad LAPI password" {
|
@test "cscli - bad LAPI password" {
|
||||||
|
rune -0 ./instance-crowdsec start
|
||||||
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path')
|
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path')
|
||||||
config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"'
|
config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"'
|
||||||
|
|
||||||
|
@ -269,6 +271,7 @@ teardown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@test "cscli metrics" {
|
@test "cscli metrics" {
|
||||||
|
rune -0 ./instance-crowdsec start
|
||||||
rune -0 cscli lapi status
|
rune -0 cscli lapi status
|
||||||
rune -0 cscli metrics
|
rune -0 cscli metrics
|
||||||
assert_output --partial "Route"
|
assert_output --partial "Route"
|
||||||
|
@ -297,6 +300,7 @@ teardown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@test "cscli explain" {
|
@test "cscli explain" {
|
||||||
|
rune -0 ./instance-crowdsec start
|
||||||
line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4"
|
line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4"
|
||||||
|
|
||||||
rune -0 cscli parsers install crowdsecurity/syslog-logs
|
rune -0 cscli parsers install crowdsecurity/syslog-logs
|
||||||
|
|
Loading…
Reference in a new issue