Ver código fonte

optimize blocklist fetch (#2039)

Cristian Nitescu 2 anos atrás
pai
commit
ecb32d74c6

+ 1 - 1
pkg/apiclient/client.go

@@ -158,7 +158,7 @@ func newResponse(r *http.Response) *Response {
 }
 
 func CheckResponse(r *http.Response) error {
-	if c := r.StatusCode; 200 <= c && c <= 299 {
+	if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
 		return nil
 	}
 	errorResponse := &ErrorResponse{}

+ 45 - 15
pkg/apiclient/decisions_service.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"context"
 	"fmt"
-	"io"
 	"net/http"
 
 	"github.com/crowdsecurity/crowdsec/pkg/models"
@@ -150,29 +149,58 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
 	return &v2Decisions, resp, nil
 }
 
-func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) ([]*models.Decision, error) {
+func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) {
 	if blocklist.URL == nil {
-		return nil, errors.New("blocklist URL is nil")
+		return nil, false, errors.New("blocklist URL is nil")
 	}
 
 	log.Debugf("Fetching blocklist %s", *blocklist.URL)
 
-	req, err := s.client.NewRequest(http.MethodGet, *blocklist.URL, nil)
+	client := http.Client{}
+	req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
 	if err != nil {
-		return nil, err
+		return nil, false, err
+	}
+
+	if lastPullTimestamp != nil {
+		req.Header.Set("If-Modified-Since", *lastPullTimestamp)
+	}
+	req = req.WithContext(ctx)
+	log.Debugf("[URL] %s %s", req.Method, req.URL)
+	// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
+	resp, err := client.Do(req)
+	if resp != nil && resp.Body != nil {
+		defer resp.Body.Close()
 	}
 
-	pr, pw := io.Pipe()
-	defer pr.Close()
-	go func() {
-		defer pw.Close()
-		_, err = s.client.Do(ctx, req, pw)
-		if err != nil {
-			log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
+	if err != nil {
+		// If we got an error, and the context has been canceled,
+		// the context's error is probably more useful.
+		select {
+		case <-ctx.Done():
+			return nil, false, ctx.Err()
+		default:
 		}
-	}()
+
+		// If the error type is *url.Error, sanitize its URL before returning.
+		log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
+		return nil, false, err
+	}
+
+	if resp.StatusCode == http.StatusNotModified {
+		if lastPullTimestamp != nil {
+			log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp)
+		} else {
+			log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
+		}
+		return nil, false, nil
+	}
+	if resp.StatusCode != http.StatusOK {
+		log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
+		return nil, false, nil
+	}
 	decisions := make([]*models.Decision, 0)
-	scanner := bufio.NewScanner(pr)
+	scanner := bufio.NewScanner(resp.Body)
 	for scanner.Scan() {
 		decision := scanner.Text()
 		decisions = append(decisions, &models.Decision{
@@ -185,7 +213,9 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
 		})
 	}
 
-	return decisions, nil
+	// here the upper go routine is finished because scanner.Scan() is blocking until pw.Close() is called
+	// so it's safe to use the isModified variable here
+	return decisions, true, nil
 }
 
 func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOpts) (*models.DecisionsStreamResponse, *Response, error) {

+ 29 - 5
pkg/apiclient/decisions_service_test.go

@@ -11,6 +11,7 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/cwversion"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
+	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -377,9 +378,11 @@ func TestDecisionsFromBlocklist(t *testing.T) {
 	defer teardown()
 
 	mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) {
-
-		assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
 		testMethod(t, r, http.MethodGet)
+		if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" {
+			w.WriteHeader(http.StatusNotModified)
+			return
+		}
 		if r.Method == http.MethodGet {
 			w.WriteHeader(http.StatusOK)
 			w.Write([]byte("1.2.3.4\r\n1.2.3.5"))
@@ -407,7 +410,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
 	tnameBlocklist := "blocklist1"
 	tremediationBlocklist := "ban"
 	tscopeBlocklist := "ip"
-	turlBlocklist := "/v3/blocklist"
+	turlBlocklist := urlx + "/v3/blocklist"
 	torigin := "lists"
 	expected := []*models.Decision{
 		{
@@ -427,14 +430,15 @@ func TestDecisionsFromBlocklist(t *testing.T) {
 			Origin:   &torigin,
 		},
 	}
-	decisions, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
+	decisions, isModified, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
 		URL:         &turlBlocklist,
 		Scope:       &tscopeBlocklist,
 		Remediation: &tremediationBlocklist,
 		Name:        &tnameBlocklist,
 		Duration:    &tdurationBlocklist,
-	})
+	}, nil)
 	require.NoError(t, err)
+	assert.True(t, isModified)
 
 	log.Infof("decision1: %+v", decisions[0])
 	log.Infof("expected1: %+v", expected[0])
@@ -448,6 +452,26 @@ func TestDecisionsFromBlocklist(t *testing.T) {
 	if !reflect.DeepEqual(decisions, expected) {
 		t.Fatalf("returned %+v, want %+v", decisions, expected)
 	}
+
+	// test cache control
+	_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
+		URL:         &turlBlocklist,
+		Scope:       &tscopeBlocklist,
+		Remediation: &tremediationBlocklist,
+		Name:        &tnameBlocklist,
+		Duration:    &tdurationBlocklist,
+	}, types.StrPtr("Sun, 01 Jan 2023 01:01:01 GMT"))
+	require.NoError(t, err)
+	assert.False(t, isModified)
+	_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
+		URL:         &turlBlocklist,
+		Scope:       &tscopeBlocklist,
+		Remediation: &tremediationBlocklist,
+		Name:        &tnameBlocklist,
+		Duration:    &tdurationBlocklist,
+	}, types.StrPtr("Mon, 02 Jan 2023 01:01:01 GMT"))
+	require.NoError(t, err)
+	assert.True(t, isModified)
 }
 
 func TestDeleteDecisions(t *testing.T) {

+ 57 - 2
pkg/apiserver/apic.go

@@ -20,6 +20,7 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
 	"github.com/crowdsecurity/crowdsec/pkg/cwversion"
 	"github.com/crowdsecurity/crowdsec/pkg/database"
+	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
@@ -598,6 +599,36 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[strin
 	return nil
 }
 
+func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) {
+	// we should force pull if the blocklist decisions are about to expire or there's no decision in the db
+	alertQuery := a.dbClient.Ent.Alert.Query()
+	alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
+	alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
+	alertInstance, err := alertQuery.First(context.Background())
+	if err != nil {
+		if ent.IsNotFound(err) {
+			log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
+			return true, nil
+		}
+		return false, errors.Wrap(err, "while getting alert")
+	}
+	decisionQuery := a.dbClient.Ent.Decision.Query()
+	decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
+	firstDecision, err := decisionQuery.First(context.Background())
+	if err != nil {
+		if ent.IsNotFound(err) {
+			log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
+			return true, nil
+		}
+		return false, errors.Wrap(err, "while getting decision")
+	}
+	if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
+		log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
+		return true, nil
+	}
+	return false, nil
+}
+
 func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error {
 	if links == nil {
 		return nil
@@ -607,7 +638,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
 	}
 	// we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles
 	// we can use the same baseUrl as the urls are absolute and the parse will take care of it
-	defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", &http.Client{})
+	defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil)
 	if err != nil {
 		return errors.Wrap(err, "while creating default client")
 	}
@@ -620,10 +651,34 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
 			log.Warningf("blocklist has no duration")
 			continue
 		}
-		decisions, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist)
+		forcePull, err := a.ShouldForcePullBlocklist(blocklist)
+		if err != nil {
+			return errors.Wrapf(err, "while checking if we should force pull blocklist %s", *blocklist.Name)
+		}
+		blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
+		var lastPullTimestamp *string
+		if !forcePull {
+			lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
+			if err != nil {
+				return errors.Wrapf(err, "while getting last pull timestamp for blocklist %s", *blocklist.Name)
+			}
+		}
+		decisions, has_changed, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
 		if err != nil {
 			return errors.Wrapf(err, "while getting decisions from blocklist %s", *blocklist.Name)
 		}
+		if !has_changed {
+			if lastPullTimestamp == nil {
+				log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
+			} else {
+				log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
+			}
+			continue
+		}
+		err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
+		if err != nil {
+			return errors.Wrapf(err, "while setting last pull timestamp for blocklist %s", *blocklist.Name)
+		}
 		if len(decisions) == 0 {
 			log.Infof("blocklist %s has no decisions", *blocklist.Name)
 			continue

+ 148 - 6
pkg/apiserver/apic_test.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"net/http"
 	"net/url"
 	"os"
 	"reflect"
@@ -568,14 +569,14 @@ func TestAPICPullTop(t *testing.T) {
 					Blocklists: []*modelscapi.BlocklistLink{
 						{
 							URL:         types.StrPtr("http://api.crowdsec.net/blocklist1"),
-							Name:        types.StrPtr("crowdsecurity/http-bf"),
+							Name:        types.StrPtr("blocklist1"),
 							Scope:       types.StrPtr("Ip"),
 							Remediation: types.StrPtr("ban"),
 							Duration:    types.StrPtr("24h"),
 						},
 						{
 							URL:         types.StrPtr("http://api.crowdsec.net/blocklist2"),
-							Name:        types.StrPtr("crowdsecurity/ssh-bf"),
+							Name:        types.StrPtr("blocklist2"),
 							Scope:       types.StrPtr("Ip"),
 							Remediation: types.StrPtr("ban"),
 							Duration:    types.StrPtr("24h"),
@@ -622,19 +623,160 @@ func TestAPICPullTop(t *testing.T) {
 	}
 	assert.Equal(t, 3, len(alertScenario))
 	assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS_ALIAS])
-	assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
-	assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
+	assert.Equal(t, 1, alertScenario["lists:blocklist1"])
+	assert.Equal(t, 1, alertScenario["lists:blocklist2"])
 
 	for _, decisions := range validDecisions {
 		decisionScenarioFreq[decisions.Scenario]++
 	}
 
-	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
-	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1)
+	assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1)
 	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
 	assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
 }
 
+func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
+	// no decision in db, no last modified parameter.
+	api := getAPIC(t)
+	httpmock.Activate()
+	defer httpmock.DeactivateAndReset()
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
+		200, jsonMarshalX(
+			modelscapi.GetDecisionsStreamResponse{
+				New: modelscapi.GetDecisionsStreamResponseNew{
+					&modelscapi.GetDecisionsStreamResponseNewItem{
+						Scenario: types.StrPtr("crowdsecurity/test1"),
+						Scope:    types.StrPtr("Ip"),
+						Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
+							{
+								Value:    types.StrPtr("1.2.3.4"),
+								Duration: types.StrPtr("24h"),
+							},
+						},
+					},
+				},
+				Links: &modelscapi.GetDecisionsStreamResponseLinks{
+					Blocklists: []*modelscapi.BlocklistLink{
+						{
+							URL:         types.StrPtr("http://api.crowdsec.net/blocklist1"),
+							Name:        types.StrPtr("blocklist1"),
+							Scope:       types.StrPtr("Ip"),
+							Remediation: types.StrPtr("ban"),
+							Duration:    types.StrPtr("24h"),
+						},
+					},
+				},
+			},
+		),
+	))
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
+		assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
+		return httpmock.NewStringResponse(200, "1.2.3.4"), nil
+	})
+	url, err := url.ParseRequestURI("http://api.crowdsec.net/")
+	require.NoError(t, err)
+
+	apic, err := apiclient.NewDefaultClient(
+		url,
+		"/api",
+		fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
+		nil,
+	)
+	require.NoError(t, err)
+
+	api.apiClient = apic
+	err = api.PullTop()
+	require.NoError(t, err)
+
+	blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *types.StrPtr("blocklist1"))
+	lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
+	require.NoError(t, err)
+	assert.NotEqual(t, "", *lastPullTimestamp)
+
+	// new call should return 304 and should not change lastPullTimestamp
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
+		assert.NotEqual(t, "", req.Header.Get("If-Modified-Since"))
+		return httpmock.NewStringResponse(304, ""), nil
+	})
+	err = api.PullTop()
+	require.NoError(t, err)
+	secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
+	require.NoError(t, err)
+	assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
+}
+
+func TestAPICPullTopBLCacheForceCall(t *testing.T) {
+	api := getAPIC(t)
+	httpmock.Activate()
+	defer httpmock.DeactivateAndReset()
+	// create a decision about to expire. It should force fetch
+	alertInstance := api.dbClient.Ent.Alert.
+		Create().
+		SetScenario("update list").
+		SetSourceScope("list:blocklist1").
+		SetSourceValue("list:blocklist1").
+		SaveX(context.Background())
+
+	api.dbClient.Ent.Decision.Create().
+		SetOrigin(types.ListOrigin).
+		SetType("ban").
+		SetValue("9.9.9.9").
+		SetScope("Ip").
+		SetScenario("blocklist1").
+		SetUntil(time.Now().Add(time.Hour)).
+		SetOwnerID(alertInstance.ID).
+		ExecX(context.Background())
+
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
+		200, jsonMarshalX(
+			modelscapi.GetDecisionsStreamResponse{
+				New: modelscapi.GetDecisionsStreamResponseNew{
+					&modelscapi.GetDecisionsStreamResponseNewItem{
+						Scenario: types.StrPtr("crowdsecurity/test1"),
+						Scope:    types.StrPtr("Ip"),
+						Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
+							{
+								Value:    types.StrPtr("1.2.3.4"),
+								Duration: types.StrPtr("24h"),
+							},
+						},
+					},
+				},
+				Links: &modelscapi.GetDecisionsStreamResponseLinks{
+					Blocklists: []*modelscapi.BlocklistLink{
+						{
+							URL:         types.StrPtr("http://api.crowdsec.net/blocklist1"),
+							Name:        types.StrPtr("blocklist1"),
+							Scope:       types.StrPtr("Ip"),
+							Remediation: types.StrPtr("ban"),
+							Duration:    types.StrPtr("24h"),
+						},
+					},
+				},
+			},
+		),
+	))
+	httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
+		assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
+		return httpmock.NewStringResponse(304, ""), nil
+	})
+	url, err := url.ParseRequestURI("http://api.crowdsec.net/")
+	require.NoError(t, err)
+
+	apic, err := apiclient.NewDefaultClient(
+		url,
+		"/api",
+		fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
+		nil,
+	)
+	require.NoError(t, err)
+
+	api.apiClient = apic
+	err = api.PullTop()
+	require.NoError(t, err)
+}
+
 func TestAPICPush(t *testing.T) {
 	tests := []struct {
 		name          string