Browse Source

manage force_pull message for one blocklist (#2615)

* manage force_pull message for one blocklist

* fix info message on force pull blocklist
Cristian Nitescu 1 year ago
parent
commit
7c5cbef51a
3 changed files with 98 additions and 10 deletions
  1. 22 7
      pkg/apiserver/apic.go
  2. 31 0
      pkg/apiserver/apic_test.go
  3. 45 3
      pkg/apiserver/papi_cmd.go

+ 22 - 7
pkg/apiserver/apic.go

@@ -618,12 +618,23 @@ func (a *apic) PullTop(forcePull bool) error {
 	}
 	}
 
 
 	// update blocklists
 	// update blocklists
-	if err := a.UpdateBlocklists(data.Links, add_counters); err != nil {
+	if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil {
 		return fmt.Errorf("while updating blocklists: %w", err)
 		return fmt.Errorf("while updating blocklists: %w", err)
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
+// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
+func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
+	add_counters, _ := makeAddAndDeleteCounters()
+	if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
+		Blocklists: []*modelscapi.BlocklistLink{blocklist},
+	}, add_counters, forcePull); err != nil {
+		return fmt.Errorf("while pulling blocklist: %w", err)
+	}
+	return nil
+}
+
 // if decisions is whitelisted: return representation of the whitelist ip or cidr
 // if decisions is whitelisted: return representation of the whitelist ip or cidr
 // if not whitelisted: empty string
 // if not whitelisted: empty string
 func (a *apic) whitelistedBy(decision *models.Decision) string {
 func (a *apic) whitelistedBy(decision *models.Decision) string {
@@ -710,7 +721,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
 	return false, nil
 	return false, nil
 }
 }
 
 
-func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int) error {
+func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error {
 	if blocklist.Scope == nil {
 	if blocklist.Scope == nil {
 		log.Warningf("blocklist has no scope")
 		log.Warningf("blocklist has no scope")
 		return nil
 		return nil
@@ -719,12 +730,16 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
 		log.Warningf("blocklist has no duration")
 		log.Warningf("blocklist has no duration")
 		return nil
 		return nil
 	}
 	}
-	forcePull, err := a.ShouldForcePullBlocklist(blocklist)
-	if err != nil {
-		return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
+	if !forcePull {
+		_forcePull, err := a.ShouldForcePullBlocklist(blocklist)
+		if err != nil {
+			return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
+		}
+		forcePull = _forcePull
 	}
 	}
 	blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
 	blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
 	var lastPullTimestamp *string
 	var lastPullTimestamp *string
+	var err error
 	if !forcePull {
 	if !forcePull {
 		lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
 		lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
 		if err != nil {
 		if err != nil {
@@ -764,7 +779,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
 	return nil
 	return nil
 }
 }
 
 
-func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error {
+func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error {
 	if links == nil {
 	if links == nil {
 		return nil
 		return nil
 	}
 	}
@@ -778,7 +793,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
 		return fmt.Errorf("while creating default client: %w", err)
 		return fmt.Errorf("while creating default client: %w", err)
 	}
 	}
 	for _, blocklist := range links.Blocklists {
 	for _, blocklist := range links.Blocklists {
-		if err := a.updateBlocklist(defaultClient, blocklist, add_counters); err != nil {
+		if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}

+ 31 - 0
pkg/apiserver/apic_test.go

@@ -973,6 +973,37 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
 	require.NoError(t, err)
 	require.NoError(t, err)
 }
 }
 
 
+func TestAPICPullBlocklistCall(t *testing.T) {
+	api := getAPIC(t)
+	httpmock.Activate()
+	defer httpmock.DeactivateAndReset()
+
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
+		assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
+		return httpmock.NewStringResponse(200, "1.2.3.4"), nil
+	})
+	url, err := url.ParseRequestURI("http://api.crowdsec.net/")
+	require.NoError(t, err)
+
+	apic, err := apiclient.NewDefaultClient(
+		url,
+		"/api",
+		fmt.Sprintf("crowdsec/%s", version.String()),
+		nil,
+	)
+	require.NoError(t, err)
+
+	api.apiClient = apic
+	err = api.PullBlocklist(&modelscapi.BlocklistLink{
+		URL:         ptr.Of("http://api.crowdsec.net/blocklist1"),
+		Name:        ptr.Of("blocklist1"),
+		Scope:       ptr.Of("Ip"),
+		Remediation: ptr.Of("ban"),
+		Duration:    ptr.Of("24h"),
+	}, true)
+	require.NoError(t, err)
+}
+
 func TestAPICPush(t *testing.T) {
 func TestAPICPush(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name          string
 		name          string

+ 45 - 3
pkg/apiserver/papi_cmd.go

@@ -11,6 +11,7 @@ import (
 
 
 	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
 	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
+	"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 )
 )
 
 
@@ -19,6 +20,23 @@ type deleteDecisions struct {
 	Decisions []string `json:"decisions"`
 	Decisions []string `json:"decisions"`
 }
 }
 
 
+type blocklistLink struct {
+	// blocklist name
+	Name string `json:"name"`
+	// blocklist url
+	Url string `json:"url"`
+	// blocklist remediation
+	Remediation string `json:"remediation"`
+	// blocklist scope
+	Scope string `json:"scope,omitempty"`
+	// blocklist duration
+	Duration string `json:"duration,omitempty"`
+}
+
+type forcePull struct {
+	Blocklist *blocklistLink `json:"blocklist,omitempty"`
+}
+
 func DecisionCmd(message *Message, p *Papi, sync bool) error {
 func DecisionCmd(message *Message, p *Papi, sync bool) error {
 	switch message.Header.OperationCmd {
 	switch message.Header.OperationCmd {
 	case "delete":
 	case "delete":
@@ -144,11 +162,35 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
 		log.Infof("Received reauth command from PAPI, resetting token")
 		log.Infof("Received reauth command from PAPI, resetting token")
 		p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken()
 		p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken()
 	case "force_pull":
 	case "force_pull":
-		log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")
-		err := p.apic.PullTop(true)
+		data, err := json.Marshal(message.Data)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("failed to force pull operation: %s", err)
+			return err
+		}
+		forcePullMsg := forcePull{}
+		if err := json.Unmarshal(data, &forcePullMsg); err != nil {
+			return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err)
 		}
 		}
+
+		if forcePullMsg.Blocklist == nil {
+			log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")
+			err = p.apic.PullTop(true)
+			if err != nil {
+				return fmt.Errorf("failed to force pull operation: %s", err)
+			}
+		} else {
+			log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)
+			err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
+				Name:        &forcePullMsg.Blocklist.Name,
+				URL:         &forcePullMsg.Blocklist.Url,
+				Remediation: &forcePullMsg.Blocklist.Remediation,
+				Scope:       &forcePullMsg.Blocklist.Scope,
+				Duration:    &forcePullMsg.Blocklist.Duration,
+			}, true)
+			if err != nil {
+				return fmt.Errorf("failed to force pull operation: %s", err)
+			}
+		}
+
 	default:
 	default:
 		return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
 		return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
 	}
 	}