Browse Source

fix null deref in cti calls if key is empty (#2540)

* fix null deref in cti calls if key is empty

* avoid hardcoded error check
Thibault "bui" Koechlin 1 year ago
parent
commit
a4dc5053d2
2 changed files with 21 additions and 18 deletions
  1. 10 17
      pkg/exprhelpers/crowdsec_cti.go
  2. 11 1
      pkg/exprhelpers/crowdsec_cti_test.go

+ 10 - 17
pkg/exprhelpers/crowdsec_cti.go

@@ -16,7 +16,7 @@ var CTIUrlSuffix = "/v2/smoke/"
 var CTIApiKey = ""
 
 // this is set for non-recoverable errors, such as 403 when querying API or empty API key
-var CTIApiEnabled = true
+var CTIApiEnabled = false
 
 // when hitting quotas or auth errors, we temporarily disable the API
 var CTIBackOffUntil time.Time
@@ -25,9 +25,9 @@ var CTIBackOffDuration time.Duration = 5 * time.Minute
 var ctiClient *cticlient.CrowdsecCTIClient
 
 func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.Level) error {
-	if Key == nil {
-		CTIApiEnabled = false
-		return fmt.Errorf("CTI API key not set, CTI will not be available")
+	if Key == nil || *Key == "" {
+		log.Warningf("CTI API key not set or empty, CTI will not be available")
+		return cticlient.ErrDisabled
 	}
 	CTIApiKey = *Key
 	if Size == nil {
@@ -38,7 +38,6 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L
 		TTL = new(time.Duration)
 		*TTL = 5 * time.Minute
 	}
-	//dedicated logger
 	clog := log.New()
 	if err := types.ConfigureLogger(clog); err != nil {
 		return errors.Wrap(err, "while configuring datasource logger")
@@ -52,6 +51,7 @@ func InitCrowdsecCTI(Key *string, TTL *time.Duration, Size *int, LogLevel *log.L
 	subLogger := clog.WithFields(customLog)
 	CrowdsecCTIInitCache(*Size, *TTL)
 	ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey(CTIApiKey), cticlient.WithLogger(subLogger))
+	CTIApiEnabled = true
 	return nil
 }
 
@@ -60,7 +60,7 @@ func ShutdownCrowdsecCTI() {
 		CTICache.Purge()
 	}
 	CTIApiKey = ""
-	CTIApiEnabled = true
+	CTIApiEnabled = false
 }
 
 // Cache for responses
@@ -74,20 +74,13 @@ func CrowdsecCTIInitCache(size int, ttl time.Duration) {
 
 // func CrowdsecCTI(ip string) (*cticlient.SmokeItem, error) {
 func CrowdsecCTI(params ...any) (any, error) {
-	ip := params[0].(string)
+	var ip string
 	if !CTIApiEnabled {
-		ctiClient.Logger.Warningf("Crowdsec CTI API is disabled, please check your configuration")
 		return &cticlient.SmokeItem{}, cticlient.ErrDisabled
 	}
-
-	if CTIApiKey == "" {
-		ctiClient.Logger.Warningf("CrowdsecCTI : no key provided, skipping")
-		return &cticlient.SmokeItem{}, cticlient.ErrDisabled
-	}
-
-	if ctiClient == nil {
-		ctiClient.Logger.Warningf("CrowdsecCTI: no client, skipping")
-		return &cticlient.SmokeItem{}, cticlient.ErrDisabled
+	var ok bool
+	if ip, ok = params[0].(string); !ok {
+		return &cticlient.SmokeItem{}, fmt.Errorf("invalid type for ip : %T", params[0])
 	}
 
 	if val, err := CTICache.Get(ip); err == nil && val != nil {

+ 11 - 1
pkg/exprhelpers/crowdsec_cti_test.go

@@ -106,6 +106,16 @@ func smokeHandler(req *http.Request) *http.Response {
 	}
 }
 
+func TestNillClient(t *testing.T) {
+	defer ShutdownCrowdsecCTI()
+	if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); err != cticlient.ErrDisabled {
+		t.Fatalf("failed to init CTI : %s", err)
+	}
+	item, err := CrowdsecCTI("1.2.3.4")
+	assert.Equal(t, err, cticlient.ErrDisabled)
+	assert.Equal(t, item, &cticlient.SmokeItem{})
+}
+
 func TestInvalidAuth(t *testing.T) {
 	defer ShutdownCrowdsecCTI()
 	if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil {
@@ -135,7 +145,7 @@ func TestInvalidAuth(t *testing.T) {
 func TestNoKey(t *testing.T) {
 	defer ShutdownCrowdsecCTI()
 	err := InitCrowdsecCTI(nil, nil, nil, nil)
-	assert.ErrorContains(t, err, "CTI API key not set")
+	assert.ErrorIs(t, err, cticlient.ErrDisabled)
 	//Replace the client created by InitCrowdsecCTI with one that uses a custom transport
 	ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{
 		Transport: RoundTripFunc(smokeHandler),