fix tests

This commit is contained in:
marco 2024-02-07 12:23:23 +01:00
parent 084c01ddc7
commit 75ff4e1e31
2 changed files with 28 additions and 26 deletions

View file

@ -109,26 +109,28 @@ func CrowdsecCTI(params ...any) (any, error) {
ctiLogger.Infof("cti call for %s", ip)
before := time.Now()
ctx := context.Background()
ctx := context.Background() // XXX: timeout?
ctiResp, err := ctiClient.GetSmokeIpWithResponse(ctx, ip)
ctiLogger.Debugf("request for %s took %v", ip, time.Since(before))
// fmt.Printf("response code: %d", ctiResp.HTTPResponse.StatusCode)
// litter.Dump(string(ctiResp.Body))
if err != nil {
switch {
case ctiResp.HTTPResponse != nil && ctiResp.HTTPResponse.StatusCode == 403:
CTIApiEnabled = false
ctiLogger.Errorf("Invalid API key provided, disabling CTI API")
return &cti.CTIObject{}, cti.ErrUnauthorized
case ctiResp.HTTPResponse != nil && ctiResp.HTTPResponse.StatusCode == 429:
CTIBackOffUntil = time.Now().Add(CTIBackOffDuration)
ctiLogger.Errorf("CTI API is throttled, will try again in %s", CTIBackOffDuration)
return &cti.CTIObject{}, cti.ErrLimit
default:
ctiLogger.Warnf("CTI API error : %s", err)
return &cti.CTIObject{}, fmt.Errorf("unexpected error : %v", err)
}
ctiLogger.Warnf("CTI API error: %s", err)
return &cti.CTIObject{}, fmt.Errorf("unexpected error: %w", err)
}
switch {
case ctiResp.HTTPResponse != nil && ctiResp.HTTPResponse.StatusCode == 403:
fmt.Printf("403 error, disabling CTI API\n")
CTIApiEnabled = false
ctiLogger.Errorf("Invalid API key provided, disabling CTI API")
return &cti.CTIObject{}, cti.ErrUnauthorized
case ctiResp.HTTPResponse != nil && ctiResp.HTTPResponse.StatusCode == 429:
CTIBackOffUntil = time.Now().Add(CTIBackOffDuration)
ctiLogger.Errorf("CTI API is throttled, will try again in %s", CTIBackOffDuration)
return &cti.CTIObject{}, cti.ErrLimit
case ctiResp.HTTPResponse != nil && ctiResp.HTTPResponse.StatusCode != 200:
ctiLogger.Warnf("CTI API error: %s", ctiResp.HTTPResponse.Status)
return &cti.CTIObject{}, fmt.Errorf("unexpected error: %s", ctiResp.HTTPResponse.Status)
}
if err := CTICache.SetWithExpire(ip, ctiResp, CacheExpiration); err != nil {
@ -136,7 +138,7 @@ func CrowdsecCTI(params ...any) (any, error) {
return &cti.CTIObject{}, cti.ErrUnknown
}
ctiLogger.Tracef("CTI response: %v", *ctiResp)
ctiLogger.Tracef("CTI response: %s", ctiResp.Body)
var ctiObject cti.CTIObject
if err := json.Unmarshal(ctiResp.Body, &ctiObject); err != nil {

View file

@ -124,24 +124,24 @@ func TestNillClient(t *testing.T) {
func TestInvalidAuth(t *testing.T) {
defer ShutdownCrowdsecCTI()
if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil {
badKey := "asdasd"
if err := InitCrowdsecCTI(&badKey, nil, nil, nil); err != nil {
t.Fatalf("failed to init CTI : %s", err)
}
var err error
//Replace the client created by InitCrowdsecCTI with one that uses a custom transport
ctiClient, err = cti.NewClientWithResponses(CTIUrl+"/v2/", cti.WithRequestEditorFn(cti.APIKeyInserter(validApiKey)), cti.WithHTTPClient(&http.Client{
ctiClient, err = cti.NewClientWithResponses(CTIUrl+"/v2/", cti.WithRequestEditorFn(cti.APIKeyInserter(badKey)), cti.WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler),
}))
require.NoError(t, err)
assert.True(t, CTIApiEnabled)
item, err := CrowdsecCTI("1.2.3.4")
// require.False(t, CTIApiEnabled)
// require.ErrorIs(t, err, cti.ErrUnauthorized)
require.Equal(t, &cti.CTIObject{Ip: "1.2.3.4"}, item)
// require.Equal(t, &cti.CTIObject{}, item)
require.False(t, CTIApiEnabled)
require.ErrorIs(t, err, cti.ErrUnauthorized)
require.Equal(t, &cti.CTIObject{}, item)
//CTI is now disabled, all requests should return empty
ctiClient, err = cti.NewClientWithResponses(CTIUrl+"/v2/", cti.WithRequestEditorFn(cti.APIKeyInserter(validApiKey)), cti.WithHTTPClient(&http.Client{
@ -150,9 +150,9 @@ func TestInvalidAuth(t *testing.T) {
require.NoError(t, err)
item, err = CrowdsecCTI("1.2.3.4")
// assert.Equal(t, item, &cti.CTIObject{})
// assert.False(t, CTIApiEnabled)
// assert.Equal(t, err, cti.ErrDisabled)
assert.Equal(t, item, &cti.CTIObject{})
assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cti.ErrDisabled)
}
func TestNoKey(t *testing.T) {