Browse Source

New PAPI commands: reauth + force_pull (#2129)

blotus 2 years ago
parent
commit
91eb39cff6

+ 1 - 1
cmd/crowdsec-cli/capi.go

@@ -20,7 +20,7 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
-const CAPIBaseURL string = "https://api.dev.crowdsec.net/"
+const CAPIBaseURL string = "https://api.crowdsec.net/"
 const CAPIURLPrefix = "v3"
 
 func NewCapiCmd() *cobra.Command {

+ 1 - 1
cmd/crowdsec-cli/papi.go

@@ -116,7 +116,7 @@ func NewPapiSyncCmd() *cobra.Command {
 			}
 			t.Go(papi.SyncDecisions)
 
-			err = papi.PullOnce(time.Time{})
+			err = papi.PullOnce(time.Time{}, true)
 
 			if err != nil {
 				log.Fatalf("unable to sync decisions: %s", err)

+ 1 - 2
cmd/crowdsec/serve.go

@@ -18,7 +18,7 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 )
 
-//nolint: deadcode,unused // debugHandler is kept as a dev convenience: it shuts down and serialize internal state
+//nolint:deadcode,unused // debugHandler is kept as a dev convenience: it shuts down and serialize internal state
 func debugHandler(sig os.Signal, cConfig *csconfig.Config) error {
 	var (
 		tmpFile string
@@ -356,7 +356,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
 		if !sent || err != nil {
 			log.Errorf("Failed to notify(sent: %v): %v", sent, err)
 		}
-
 		// wait for signals
 		return HandleSignals(cConfig)
 	}

+ 22 - 7
pkg/apiserver/apic.go

@@ -61,6 +61,7 @@ type apic struct {
 	credentials   *csconfig.ApiCredentialsCfg
 	scenarioList  []string
 	consoleConfig *csconfig.ConsoleConfig
+	isPulling     chan bool
 	whitelists    *csconfig.CapiWhitelist
 }
 
@@ -171,6 +172,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
 		pushIntervalFirst:    randomDuration(pushIntervalDefault, pushIntervalDelta),
 		metricsInterval:      metricsIntervalDefault,
 		metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta),
+		isPulling:            make(chan bool, 1),
 		whitelists:           apicWhitelist,
 	}
 
@@ -537,13 +539,26 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
 // we receive a list of decisions and links for blocklist and we need to create a list of alerts :
 // one alert for "community blocklist"
 // one alert per list we're subscribed to
-func (a *apic) PullTop() error {
+func (a *apic) PullTop(forcePull bool) error {
 	var err error
 
-	if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
-		return err
-	} else if !lastPullIsOld {
-		return nil
+	//A mutex with TryLock would be a bit simpler
+	//But go does not guarantee that TryLock will be able to acquire the lock even if it is available
+	select {
+	case a.isPulling <- true:
+		defer func() {
+			<-a.isPulling
+		}()
+	default:
+		return errors.New("pull already in progress")
+	}
+
+	if !forcePull {
+		if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
+			return err
+		} else if !lastPullIsOld {
+			return nil
+		}
 	}
 
 	log.Infof("Starting community-blocklist update")
@@ -780,7 +795,7 @@ func (a *apic) Pull() error {
 		}
 		time.Sleep(1 * time.Second)
 	}
-	if err := a.PullTop(); err != nil {
+	if err := a.PullTop(false); err != nil {
 		log.Errorf("capi pull top: %s", err)
 	}
 
@@ -791,7 +806,7 @@ func (a *apic) Pull() error {
 		select {
 		case <-ticker.C:
 			ticker.Reset(a.pullInterval)
-			if err := a.PullTop(); err != nil {
+			if err := a.PullTop(false); err != nil {
 				log.Errorf("capi pull top: %s", err)
 				continue
 			}

+ 6 - 5
pkg/apiserver/apic_test.go

@@ -64,6 +64,7 @@ func getAPIC(t *testing.T) *apic {
 			ShareCustomScenarios:  types.BoolPtr(false),
 			ShareContext:          types.BoolPtr(false),
 		},
+		isPulling: make(chan bool, 1),
 	}
 }
 
@@ -666,7 +667,7 @@ func TestAPICWhitelists(t *testing.T) {
 	require.NoError(t, err)
 
 	api.apiClient = apic
-	err = api.PullTop()
+	err = api.PullTop(false)
 	require.NoError(t, err)
 
 	assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing
@@ -797,7 +798,7 @@ func TestAPICPullTop(t *testing.T) {
 	require.NoError(t, err)
 
 	api.apiClient = apic
-	err = api.PullTop()
+	err = api.PullTop(false)
 	require.NoError(t, err)
 
 	assertTotalDecisionCount(t, api.dbClient, 5)
@@ -879,7 +880,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
 	require.NoError(t, err)
 
 	api.apiClient = apic
-	err = api.PullTop()
+	err = api.PullTop(false)
 	require.NoError(t, err)
 
 	blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *types.StrPtr("blocklist1"))
@@ -892,7 +893,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
 		assert.NotEqual(t, "", req.Header.Get("If-Modified-Since"))
 		return httpmock.NewStringResponse(304, ""), nil
 	})
-	err = api.PullTop()
+	err = api.PullTop(false)
 	require.NoError(t, err)
 	secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
 	require.NoError(t, err)
@@ -966,7 +967,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
 	require.NoError(t, err)
 
 	api.apiClient = apic
-	err = api.PullTop()
+	err = api.PullTop(false)
 	require.NoError(t, err)
 }
 

+ 3 - 0
pkg/apiserver/apiserver.go

@@ -420,6 +420,9 @@ func (s *APIServer) Close() {
 	if s.apic != nil {
 		s.apic.Shutdown() // stop apic first since it use dbClient
 	}
+	if s.papi != nil {
+		s.papi.Shutdown() // papi also uses the dbClient
+	}
 	s.dbClient.Ent.Close()
 	if s.flushScheduler != nil {
 		s.flushScheduler.Stop()

+ 12 - 10
pkg/apiserver/papi.go

@@ -15,7 +15,6 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/pkg/errors"
-	"github.com/sirupsen/logrus"
 	log "github.com/sirupsen/logrus"
 	"gopkg.in/tomb.v2"
 )
@@ -29,9 +28,10 @@ const (
 )
 
 var (
-	operationMap = map[string]func(*Message, *Papi) error{
-		"decision": DecisionCmd,
-		"alert":    AlertCmd,
+	operationMap = map[string]func(*Message, *Papi, bool) error{
+		"decision":   DecisionCmd,
+		"alert":      AlertCmd,
+		"management": ManagementCmd,
 	}
 )
 
@@ -71,6 +71,7 @@ type Papi struct {
 	SyncInterval  time.Duration
 	consoleConfig *csconfig.ConsoleConfig
 	Logger        *log.Entry
+	apic          *apic
 }
 
 type PapiPermCheckError struct {
@@ -85,7 +86,7 @@ type PapiPermCheckSuccess struct {
 
 func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) {
 
-	logger := logrus.New()
+	logger := log.New()
 	if err := types.ConfigureLogger(logger); err != nil {
 		return &Papi{}, fmt.Errorf("creating papi logger: %s", err)
 	}
@@ -118,6 +119,7 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons
 		pullTomb:      tomb.Tomb{},
 		syncTomb:      tomb.Tomb{},
 		apiClient:     apic.apiClient,
+		apic:          apic,
 		consoleConfig: consoleConfig,
 		Logger:        logger.WithFields(log.Fields{"interval": SyncInterval.Seconds(), "source": "papi"}),
 	}
@@ -125,7 +127,7 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons
 	return papi, nil
 }
 
-func (p *Papi) handleEvent(event longpollclient.Event) error {
+func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error {
 	logger := p.Logger.WithField("request-id", event.RequestId)
 	logger.Debugf("message received: %+v", event.Data)
 	message := &Message{}
@@ -141,7 +143,7 @@ func (p *Papi) handleEvent(event longpollclient.Event) error {
 
 	if operationFunc, ok := operationMap[message.Header.OperationType]; ok {
 		logger.Debugf("Calling operation '%s'", message.Header.OperationType)
-		err := operationFunc(message, p)
+		err := operationFunc(message, p, sync)
 		if err != nil {
 			return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err)
 		}
@@ -192,7 +194,7 @@ func reverse(s []longpollclient.Event) []longpollclient.Event {
 	return a
 }
 
-func (p *Papi) PullOnce(since time.Time) error {
+func (p *Papi) PullOnce(since time.Time, sync bool) error {
 	events, err := p.Client.PullOnce(since)
 	if err != nil {
 		return err
@@ -202,7 +204,7 @@ func (p *Papi) PullOnce(since time.Time) error {
 	eventsCount := len(events)
 	p.Logger.Infof("received %d events", eventsCount)
 	for i, event := range reversedEvents {
-		if err := p.handleEvent(event); err != nil {
+		if err := p.handleEvent(event, sync); err != nil {
 			p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err)
 		}
 		p.Logger.Debugf("handled event %d/%d", i, eventsCount)
@@ -251,7 +253,7 @@ func (p *Papi) Pull() error {
 			return errors.Wrap(err, "failed to marshal last timestamp")
 		}
 
-		err = p.handleEvent(event)
+		err = p.handleEvent(event, false)
 		if err != nil {
 			logger.Errorf("failed to handle event: %s", err)
 			continue

+ 24 - 2
pkg/apiserver/papi_cmd.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
 	"github.com/crowdsecurity/crowdsec/pkg/models"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/pkg/errors"
@@ -16,7 +17,7 @@ type deleteDecisions struct {
 	Decisions []string `json:"decisions"`
 }
 
-func DecisionCmd(message *Message, p *Papi) error {
+func DecisionCmd(message *Message, p *Papi, sync bool) error {
 	switch message.Header.OperationCmd {
 	case "delete":
 
@@ -64,7 +65,7 @@ func DecisionCmd(message *Message, p *Papi) error {
 	return nil
 }
 
-func AlertCmd(message *Message, p *Papi) error {
+func AlertCmd(message *Message, p *Papi, sync bool) error {
 	switch message.Header.OperationCmd {
 	case "add":
 		data, err := json.Marshal(message.Data)
@@ -130,3 +131,24 @@ func AlertCmd(message *Message, p *Papi) error {
 
 	return nil
 }
+
+func ManagementCmd(message *Message, p *Papi, sync bool) error {
+	if sync {
+		log.Infof("Ignoring management command from PAPI in sync mode")
+		return nil
+	}
+	switch message.Header.OperationCmd {
+	case "reauth":
+		log.Infof("Received reauth command from PAPI, resetting token")
+		p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken()
+	case "force_pull":
+		log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")
+		err := p.apic.PullTop(true)
+		if err != nil {
+			return fmt.Errorf("failed to force pull operation: %s", err)
+		}
+	default:
+		return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
+	}
+	return nil
+}