|
@@ -8,13 +8,13 @@ import (
|
|
"net/url"
|
|
"net/url"
|
|
"os"
|
|
"os"
|
|
"reflect"
|
|
"reflect"
|
|
- "sort"
|
|
|
|
"sync"
|
|
"sync"
|
|
"testing"
|
|
"testing"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
|
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
|
|
|
+ "github.com/crowdsecurity/crowdsec/pkg/cstest"
|
|
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
|
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
"github.com/crowdsecurity/crowdsec/pkg/database"
|
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
|
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
|
|
@@ -24,23 +24,20 @@ import (
|
|
"github.com/jarcoal/httpmock"
|
|
"github.com/jarcoal/httpmock"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
+ "github.com/stretchr/testify/require"
|
|
"gopkg.in/tomb.v2"
|
|
"gopkg.in/tomb.v2"
|
|
)
|
|
)
|
|
|
|
|
|
func getDBClient(t *testing.T) *database.Client {
|
|
func getDBClient(t *testing.T) *database.Client {
|
|
t.Helper()
|
|
t.Helper()
|
|
dbPath, err := os.CreateTemp("", "*sqlite")
|
|
dbPath, err := os.CreateTemp("", "*sqlite")
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
|
|
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
|
|
Type: "sqlite",
|
|
Type: "sqlite",
|
|
DbName: "crowdsec",
|
|
DbName: "crowdsec",
|
|
DbPath: dbPath.Name(),
|
|
DbPath: dbPath.Name(),
|
|
})
|
|
})
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
return dbClient
|
|
return dbClient
|
|
}
|
|
}
|
|
|
|
|
|
@@ -98,11 +95,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
|
|
|
|
|
|
func TestAPICCAPIPullIsOld(t *testing.T) {
|
|
func TestAPICCAPIPullIsOld(t *testing.T) {
|
|
api := getAPIC(t)
|
|
api := getAPIC(t)
|
|
- isOld, err := api.CAPIPullIsOld()
|
|
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
+ isOld, err := api.CAPIPullIsOld()
|
|
|
|
+ require.NoError(t, err)
|
|
assert.True(t, isOld)
|
|
assert.True(t, isOld)
|
|
|
|
|
|
decision := api.dbClient.Ent.Decision.Create().
|
|
decision := api.dbClient.Ent.Decision.Create().
|
|
@@ -123,16 +118,13 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
|
|
SaveX(context.Background())
|
|
SaveX(context.Background())
|
|
|
|
|
|
isOld, err = api.CAPIPullIsOld()
|
|
isOld, err = api.CAPIPullIsOld()
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
assert.False(t, isOld)
|
|
assert.False(t, isOld)
|
|
}
|
|
}
|
|
|
|
|
|
func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
|
func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
|
- api := getAPIC(t)
|
|
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
machineIDsWithScenarios map[string]string
|
|
machineIDsWithScenarios map[string]string
|
|
expectedScenarios []string
|
|
expectedScenarios []string
|
|
@@ -154,8 +146,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
- for _, tc := range testCases {
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
+ api := getAPIC(t)
|
|
for machineID, scenarios := range tc.machineIDsWithScenarios {
|
|
for machineID, scenarios := range tc.machineIDsWithScenarios {
|
|
api.dbClient.Ent.Machine.Create().
|
|
api.dbClient.Ent.Machine.Create().
|
|
SetMachineId(machineID).
|
|
SetMachineId(machineID).
|
|
@@ -164,17 +158,14 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
|
|
SetScenarios(scenarios).
|
|
SetScenarios(scenarios).
|
|
ExecX(context.Background())
|
|
ExecX(context.Background())
|
|
}
|
|
}
|
|
|
|
+
|
|
scenarios, err := api.FetchScenariosListFromDB()
|
|
scenarios, err := api.FetchScenariosListFromDB()
|
|
for machineID := range tc.machineIDsWithScenarios {
|
|
for machineID := range tc.machineIDsWithScenarios {
|
|
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
|
|
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
|
|
}
|
|
}
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- } else {
|
|
|
|
- sort.Strings(scenarios)
|
|
|
|
- sort.Strings(tc.expectedScenarios)
|
|
|
|
- assert.Equal(t, scenarios, tc.expectedScenarios)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
|
|
+ assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
|
|
})
|
|
})
|
|
|
|
|
|
}
|
|
}
|
|
@@ -196,11 +187,10 @@ func TestNewAPIC(t *testing.T) {
|
|
consoleConfig *csconfig.ConsoleConfig
|
|
consoleConfig *csconfig.ConsoleConfig
|
|
}
|
|
}
|
|
tests := []struct {
|
|
tests := []struct {
|
|
- name string
|
|
|
|
- args args
|
|
|
|
- wantErr bool
|
|
|
|
- errorContains string
|
|
|
|
- action func()
|
|
|
|
|
|
+ name string
|
|
|
|
+ args args
|
|
|
|
+ expectedErr string
|
|
|
|
+ action func()
|
|
}{
|
|
}{
|
|
{
|
|
{
|
|
name: "simple",
|
|
name: "simple",
|
|
@@ -217,20 +207,16 @@ func TestNewAPIC(t *testing.T) {
|
|
dbClient: getDBClient(t),
|
|
dbClient: getDBClient(t),
|
|
consoleConfig: LoadTestConfig().API.Server.ConsoleConfig,
|
|
consoleConfig: LoadTestConfig().API.Server.ConsoleConfig,
|
|
},
|
|
},
|
|
- wantErr: true,
|
|
|
|
- errorContains: "first path segment in URL cannot contain colon",
|
|
|
|
|
|
+ expectedErr: "first path segment in URL cannot contain colon",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
- for _, tt := range tests {
|
|
|
|
- t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
setConfig()
|
|
setConfig()
|
|
- tt.action()
|
|
|
|
- _, err := NewAPIC(testConfig, tt.args.dbClient, tt.args.consoleConfig)
|
|
|
|
- if tt.wantErr {
|
|
|
|
- assert.ErrorContains(t, err, tt.errorContains)
|
|
|
|
- } else {
|
|
|
|
- assert.NoError(t, err)
|
|
|
|
- }
|
|
|
|
|
|
+ tc.action()
|
|
|
|
+ _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig)
|
|
|
|
+ cstest.RequireErrorContains(t, err, tc.expectedErr)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -268,17 +254,16 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
|
|
}}, deleteCounters)
|
|
}}, deleteCounters)
|
|
|
|
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, err)
|
|
- assert.Equal(t, nbDeleted, 2)
|
|
|
|
- assert.Equal(t, deleteCounters[SCOPE_CAPI]["all"], 2)
|
|
|
|
|
|
+ assert.Equal(t, 2, nbDeleted)
|
|
|
|
+ assert.Equal(t, 2, deleteCounters[SCOPE_CAPI]["all"])
|
|
}
|
|
}
|
|
|
|
|
|
func TestAPICGetMetrics(t *testing.T) {
|
|
func TestAPICGetMetrics(t *testing.T) {
|
|
- api := getAPIC(t)
|
|
|
|
- cleanUp := func() {
|
|
|
|
|
|
+ cleanUp := func(api *apic) {
|
|
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
|
|
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
|
|
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
|
|
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
|
|
}
|
|
}
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
machineIDs []string
|
|
machineIDs []string
|
|
bouncers []string
|
|
bouncers []string
|
|
@@ -322,11 +307,13 @@ func TestAPICGetMetrics(t *testing.T) {
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
- for _, testCase := range testCases {
|
|
|
|
- t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
- cleanUp()
|
|
|
|
- for i, machineID := range testCase.machineIDs {
|
|
|
|
- api.dbClient.Ent.Machine.Create().
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
|
+ apiClient := getAPIC(t)
|
|
|
|
+ cleanUp(apiClient)
|
|
|
|
+ for i, machineID := range tc.machineIDs {
|
|
|
|
+ apiClient.dbClient.Ent.Machine.Create().
|
|
SetMachineId(machineID).
|
|
SetMachineId(machineID).
|
|
SetPassword(testPassword.String()).
|
|
SetPassword(testPassword.String()).
|
|
SetIpAddress(fmt.Sprintf("1.2.3.%d", i)).
|
|
SetIpAddress(fmt.Sprintf("1.2.3.%d", i)).
|
|
@@ -336,8 +323,8 @@ func TestAPICGetMetrics(t *testing.T) {
|
|
ExecX(context.Background())
|
|
ExecX(context.Background())
|
|
}
|
|
}
|
|
|
|
|
|
- for i, bouncerName := range testCase.bouncers {
|
|
|
|
- api.dbClient.Ent.Bouncer.Create().
|
|
|
|
|
|
+ for i, bouncerName := range tc.bouncers {
|
|
|
|
+ apiClient.dbClient.Ent.Bouncer.Create().
|
|
SetIPAddress(fmt.Sprintf("1.2.3.%d", i)).
|
|
SetIPAddress(fmt.Sprintf("1.2.3.%d", i)).
|
|
SetName(bouncerName).
|
|
SetName(bouncerName).
|
|
SetAPIKey("foobar").
|
|
SetAPIKey("foobar").
|
|
@@ -346,19 +333,17 @@ func TestAPICGetMetrics(t *testing.T) {
|
|
ExecX(context.Background())
|
|
ExecX(context.Background())
|
|
}
|
|
}
|
|
|
|
|
|
- if foundMetrics, err := api.GetMetrics(); err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- } else {
|
|
|
|
- assert.Equal(t, foundMetrics.Bouncers, testCase.expectedMetric.Bouncers)
|
|
|
|
- assert.Equal(t, foundMetrics.Machines, testCase.expectedMetric.Machines)
|
|
|
|
|
|
+ foundMetrics, err := apiClient.GetMetrics()
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
|
|
+ assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
|
|
|
|
+ assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
|
|
|
|
|
|
- }
|
|
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func TestCreateAlertsForDecision(t *testing.T) {
|
|
func TestCreateAlertsForDecision(t *testing.T) {
|
|
-
|
|
|
|
httpBfDecisionList := &models.Decision{
|
|
httpBfDecisionList := &models.Decision{
|
|
Origin: &SCOPE_LISTS,
|
|
Origin: &SCOPE_LISTS,
|
|
Scenario: types.StrPtr("crowdsecurity/http-bf"),
|
|
Scenario: types.StrPtr("crowdsecurity/http-bf"),
|
|
@@ -427,10 +412,11 @@ func TestCreateAlertsForDecision(t *testing.T) {
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
- for _, tt := range tests {
|
|
|
|
- t.Run(tt.name, func(t *testing.T) {
|
|
|
|
- if got := createAlertsForDecisions(tt.args.decisions); !reflect.DeepEqual(got, tt.want) {
|
|
|
|
- t.Errorf("createAlertsForDecisions() = %v, want %v", got, tt.want)
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
|
+ if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) {
|
|
|
|
+ t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want)
|
|
}
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|
|
@@ -503,11 +489,12 @@ func TestFillAlertsWithDecisions(t *testing.T) {
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
- for _, tt := range tests {
|
|
|
|
- t.Run(tt.name, func(t *testing.T) {
|
|
|
|
- add_counters, _ := makeAddAndDeleteCounters()
|
|
|
|
- if got := fillAlertsWithDecisions(tt.args.alerts, tt.args.decisions, add_counters); !reflect.DeepEqual(got, tt.want) {
|
|
|
|
- t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tt.want)
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
|
+ addCounters, _ := makeAddAndDeleteCounters()
|
|
|
|
+ if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) {
|
|
|
|
+ t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tc.want)
|
|
}
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|
|
@@ -586,24 +573,19 @@ func TestAPICPullTop(t *testing.T) {
|
|
),
|
|
),
|
|
))
|
|
))
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
apic, err := apiclient.NewDefaultClient(
|
|
apic, err := apiclient.NewDefaultClient(
|
|
url,
|
|
url,
|
|
"/api",
|
|
"/api",
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
nil,
|
|
nil,
|
|
)
|
|
)
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
api.apiClient = apic
|
|
api.apiClient = apic
|
|
err = api.PullTop()
|
|
err = api.PullTop()
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
assertTotalDecisionCount(t, api.dbClient, 5)
|
|
assertTotalDecisionCount(t, api.dbClient, 5)
|
|
assertTotalValidDecisionCount(t, api.dbClient, 4)
|
|
assertTotalValidDecisionCount(t, api.dbClient, 4)
|
|
@@ -619,24 +601,23 @@ func TestAPICPullTop(t *testing.T) {
|
|
for _, alert := range alerts {
|
|
for _, alert := range alerts {
|
|
alertScenario[alert.SourceScope]++
|
|
alertScenario[alert.SourceScope]++
|
|
}
|
|
}
|
|
- assert.Equal(t, len(alertScenario), 3)
|
|
|
|
- assert.Equal(t, alertScenario[SCOPE_CAPI_ALIAS], 1)
|
|
|
|
- assert.Equal(t, alertScenario["lists:crowdsecurity/ssh-bf"], 1)
|
|
|
|
- assert.Equal(t, alertScenario["lists:crowdsecurity/http-bf"], 1)
|
|
|
|
|
|
+ assert.Equal(t, 3, len(alertScenario))
|
|
|
|
+ assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS])
|
|
|
|
+ assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
|
|
|
|
+ assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
|
|
|
|
|
|
for _, decisions := range validDecisions {
|
|
for _, decisions := range validDecisions {
|
|
decisionScenarioFreq[decisions.Scenario]++
|
|
decisionScenarioFreq[decisions.Scenario]++
|
|
}
|
|
}
|
|
|
|
|
|
- assert.Equal(t, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
|
|
|
|
- assert.Equal(t, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
|
|
|
|
- assert.Equal(t, decisionScenarioFreq["crowdsecurity/test1"], 1)
|
|
|
|
- assert.Equal(t, decisionScenarioFreq["crowdsecurity/test2"], 1)
|
|
|
|
|
|
+ assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
|
|
|
|
+ assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
|
|
|
|
+ assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
|
|
|
|
+ assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
|
|
}
|
|
}
|
|
|
|
|
|
func TestAPICPush(t *testing.T) {
|
|
func TestAPICPush(t *testing.T) {
|
|
-
|
|
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
alerts []*models.Alert
|
|
alerts []*models.Alert
|
|
expectedCalls int
|
|
expectedCalls int
|
|
@@ -683,14 +664,14 @@ func TestAPICPush(t *testing.T) {
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
- for _, testCase := range testCases {
|
|
|
|
- t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
api := getAPIC(t)
|
|
api := getAPIC(t)
|
|
api.pushInterval = time.Millisecond
|
|
api.pushInterval = time.Millisecond
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
httpmock.Activate()
|
|
httpmock.Activate()
|
|
defer httpmock.DeactivateAndReset()
|
|
defer httpmock.DeactivateAndReset()
|
|
apic, err := apiclient.NewDefaultClient(
|
|
apic, err := apiclient.NewDefaultClient(
|
|
@@ -699,31 +680,28 @@ func TestAPICPush(t *testing.T) {
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
nil,
|
|
nil,
|
|
)
|
|
)
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
api.apiClient = apic
|
|
api.apiClient = apic
|
|
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{}))
|
|
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{}))
|
|
go func() {
|
|
go func() {
|
|
- api.alertToPush <- testCase.alerts
|
|
|
|
|
|
+ api.alertToPush <- tc.alerts
|
|
time.Sleep(time.Second)
|
|
time.Sleep(time.Second)
|
|
api.Shutdown()
|
|
api.Shutdown()
|
|
}()
|
|
}()
|
|
- if err := api.Push(); err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
- assert.Equal(t, httpmock.GetTotalCallCount(), testCase.expectedCalls)
|
|
|
|
|
|
+ err = api.Push()
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+ assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func TestAPICSendMetrics(t *testing.T) {
|
|
func TestAPICSendMetrics(t *testing.T) {
|
|
- api := getAPIC(t)
|
|
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
duration time.Duration
|
|
duration time.Duration
|
|
expectedCalls int
|
|
expectedCalls int
|
|
- setUp func()
|
|
|
|
|
|
+ setUp func(*apic)
|
|
metricsInterval time.Duration
|
|
metricsInterval time.Duration
|
|
}{
|
|
}{
|
|
{
|
|
{
|
|
@@ -731,14 +709,15 @@ func TestAPICSendMetrics(t *testing.T) {
|
|
duration: time.Millisecond * 30,
|
|
duration: time.Millisecond * 30,
|
|
metricsInterval: time.Millisecond * 5,
|
|
metricsInterval: time.Millisecond * 5,
|
|
expectedCalls: 5,
|
|
expectedCalls: 5,
|
|
- setUp: func() {},
|
|
|
|
|
|
+ setUp: func(api *apic) {},
|
|
},
|
|
},
|
|
{
|
|
{
|
|
name: "with some metrics",
|
|
name: "with some metrics",
|
|
duration: time.Millisecond * 30,
|
|
duration: time.Millisecond * 30,
|
|
metricsInterval: time.Millisecond * 5,
|
|
metricsInterval: time.Millisecond * 5,
|
|
expectedCalls: 5,
|
|
expectedCalls: 5,
|
|
- setUp: func() {
|
|
|
|
|
|
+ setUp: func(api *apic) {
|
|
|
|
+ api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
|
|
api.dbClient.Ent.Machine.Create().
|
|
api.dbClient.Ent.Machine.Create().
|
|
SetMachineId("1234").
|
|
SetMachineId("1234").
|
|
SetPassword(testPassword.String()).
|
|
SetPassword(testPassword.String()).
|
|
@@ -748,6 +727,7 @@ func TestAPICSendMetrics(t *testing.T) {
|
|
SetUpdatedAt(time.Time{}).
|
|
SetUpdatedAt(time.Time{}).
|
|
ExecX(context.Background())
|
|
ExecX(context.Background())
|
|
|
|
|
|
|
|
+ api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
|
|
api.dbClient.Ent.Bouncer.Create().
|
|
api.dbClient.Ent.Bouncer.Create().
|
|
SetIPAddress("1.2.3.6").
|
|
SetIPAddress("1.2.3.6").
|
|
SetName("someBouncer").
|
|
SetName("someBouncer").
|
|
@@ -758,44 +738,49 @@ func TestAPICSendMetrics(t *testing.T) {
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
- for _, testCase := range testCases {
|
|
|
|
- t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
- api = getAPIC(t)
|
|
|
|
- api.pushInterval = time.Millisecond
|
|
|
|
|
|
+
|
|
|
|
+ httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
|
|
|
|
+ httpmock.Activate()
|
|
|
|
+ defer httpmock.Deactivate()
|
|
|
|
+
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
- httpmock.Activate()
|
|
|
|
- defer httpmock.DeactivateAndReset()
|
|
|
|
- apic, err := apiclient.NewDefaultClient(
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
+
|
|
|
|
+ apiClient, err := apiclient.NewDefaultClient(
|
|
url,
|
|
url,
|
|
"/api",
|
|
"/api",
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
nil,
|
|
nil,
|
|
)
|
|
)
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
- api.apiClient = apic
|
|
|
|
- api.metricsInterval = testCase.metricsInterval
|
|
|
|
- httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, []byte{}))
|
|
|
|
- testCase.setUp()
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
|
|
|
|
- go func() {
|
|
|
|
- if err := api.SendMetrics(); err != nil {
|
|
|
|
- panic(err)
|
|
|
|
- }
|
|
|
|
- }()
|
|
|
|
- time.Sleep(testCase.duration)
|
|
|
|
- assert.LessOrEqual(t, absDiff(testCase.expectedCalls, httpmock.GetTotalCallCount()), 2)
|
|
|
|
|
|
+ api := getAPIC(t)
|
|
|
|
+ api.pushInterval = time.Millisecond
|
|
|
|
+ api.apiClient = apiClient
|
|
|
|
+ api.metricsInterval = tc.metricsInterval
|
|
|
|
+ tc.setUp(api)
|
|
|
|
+
|
|
|
|
+ stop := make(chan bool)
|
|
|
|
+ httpmock.ZeroCallCounters()
|
|
|
|
+ go api.SendMetrics(stop)
|
|
|
|
+ time.Sleep(tc.duration)
|
|
|
|
+ stop <- true
|
|
|
|
+
|
|
|
|
+ info := httpmock.GetCallCountInfo()
|
|
|
|
+ noResponderCalls := info["NO_RESPONDER"]
|
|
|
|
+ responderCalls := info["POST http://api.crowdsec.net/api/metrics/"]
|
|
|
|
+ assert.LessOrEqual(t, absDiff(tc.expectedCalls, responderCalls), 2)
|
|
|
|
+ assert.Zero(t, noResponderCalls)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func TestAPICPull(t *testing.T) {
|
|
func TestAPICPull(t *testing.T) {
|
|
api := getAPIC(t)
|
|
api := getAPIC(t)
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
setUp func()
|
|
setUp func()
|
|
expectedDecisionCount int
|
|
expectedDecisionCount int
|
|
@@ -820,14 +805,13 @@ func TestAPICPull(t *testing.T) {
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
- for _, testCase := range testCases {
|
|
|
|
- t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
api = getAPIC(t)
|
|
api = getAPIC(t)
|
|
api.pullInterval = time.Millisecond
|
|
api.pullInterval = time.Millisecond
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
httpmock.Activate()
|
|
httpmock.Activate()
|
|
defer httpmock.DeactivateAndReset()
|
|
defer httpmock.DeactivateAndReset()
|
|
apic, err := apiclient.NewDefaultClient(
|
|
apic, err := apiclient.NewDefaultClient(
|
|
@@ -836,9 +820,7 @@ func TestAPICPull(t *testing.T) {
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
|
nil,
|
|
nil,
|
|
)
|
|
)
|
|
- if err != nil {
|
|
|
|
- t.Fatal(err)
|
|
|
|
- }
|
|
|
|
|
|
+ require.NoError(t, err)
|
|
api.apiClient = apic
|
|
api.apiClient = apic
|
|
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX(
|
|
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX(
|
|
models.DecisionsStreamResponse{
|
|
models.DecisionsStreamResponse{
|
|
@@ -854,7 +836,7 @@ func TestAPICPull(t *testing.T) {
|
|
},
|
|
},
|
|
},
|
|
},
|
|
)))
|
|
)))
|
|
- testCase.setUp()
|
|
|
|
|
|
+ tc.setUp()
|
|
var buf bytes.Buffer
|
|
var buf bytes.Buffer
|
|
go func() {
|
|
go func() {
|
|
logrus.SetOutput(&buf)
|
|
logrus.SetOutput(&buf)
|
|
@@ -865,15 +847,14 @@ func TestAPICPull(t *testing.T) {
|
|
//Slightly long because the CI runner for windows are slow, and this can lead to random failure
|
|
//Slightly long because the CI runner for windows are slow, and this can lead to random failure
|
|
time.Sleep(time.Millisecond * 500)
|
|
time.Sleep(time.Millisecond * 500)
|
|
logrus.SetOutput(os.Stderr)
|
|
logrus.SetOutput(os.Stderr)
|
|
- assert.Contains(t, buf.String(), testCase.logContains)
|
|
|
|
- assertTotalDecisionCount(t, api.dbClient, testCase.expectedDecisionCount)
|
|
|
|
|
|
+ assert.Contains(t, buf.String(), tc.logContains)
|
|
|
|
+ assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func TestShouldShareAlert(t *testing.T) {
|
|
func TestShouldShareAlert(t *testing.T) {
|
|
-
|
|
|
|
- testCases := []struct {
|
|
|
|
|
|
+ tests := []struct {
|
|
name string
|
|
name string
|
|
consoleConfig *csconfig.ConsoleConfig
|
|
consoleConfig *csconfig.ConsoleConfig
|
|
alert *models.Alert
|
|
alert *models.Alert
|
|
@@ -948,10 +929,11 @@ func TestShouldShareAlert(t *testing.T) {
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
|
|
- for _, testCase := range testCases {
|
|
|
|
- t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
- ret := shouldShareAlert(testCase.alert, testCase.consoleConfig)
|
|
|
|
- assert.Equal(t, ret, testCase.expectedRet)
|
|
|
|
|
|
+ for _, tc := range tests {
|
|
|
|
+ tc := tc
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
|
+ ret := shouldShareAlert(tc.alert, tc.consoleConfig)
|
|
|
|
+ assert.Equal(t, tc.expectedRet, ret)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|