S3 acquisition datasource (#2130)
This commit is contained in:
parent
a74e424d53
commit
dc38e5ac00
5 changed files with 1081 additions and 0 deletions
1
go.mod
1
go.mod
|
@ -98,6 +98,7 @@ require (
|
|||
github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect
|
||||
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect
|
||||
github.com/aws/aws-lambda-go v1.38.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/containerd/containerd v1.6.18 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
|
||||
|
|
2
go.sum
2
go.sum
|
@ -107,6 +107,8 @@ github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:o
|
|||
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
|
||||
github.com/aws/aws-lambda-go v1.38.0 h1:4CUdxGzvuQp0o8Zh7KtupB9XvCiiY8yKqJtzco+gsDw=
|
||||
github.com/aws/aws-lambda-go v1.38.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM=
|
||||
github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48=
|
||||
github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs=
|
||||
github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs=
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka"
|
||||
kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis"
|
||||
k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit"
|
||||
s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3"
|
||||
syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog"
|
||||
wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog"
|
||||
|
||||
|
@ -52,6 +53,7 @@ var AcquisitionSources = map[string]func() DataSource{
|
|||
"wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} },
|
||||
"kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} },
|
||||
"k8s_audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} },
|
||||
"s3": func() DataSource { return &s3acquisition.S3Source{} },
|
||||
}
|
||||
|
||||
func GetDataSourceIface(dataSourceType string) DataSource {
|
||||
|
|
647
pkg/acquisition/modules/s3/s3.go
Normal file
647
pkg/acquisition/modules/s3/s3.go
Normal file
|
@ -0,0 +1,647 @@
|
|||
package s3acquisition
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-lambda-go/events"
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3iface"
|
||||
"github.com/aws/aws-sdk-go/service/sqs"
|
||||
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
|
||||
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
|
||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/tomb.v2"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// S3Configuration is the YAML configuration of the S3 datasource.
type S3Configuration struct {
	configuration.DataSourceCommonCfg `yaml:",inline"`
	AwsProfile                        *string `yaml:"aws_profile"` // shared-config profile; nil uses the default credential chain
	AwsRegion                         string  `yaml:"aws_region"`
	AwsEndpoint                       string  `yaml:"aws_endpoint"` // custom endpoint override -- presumably for S3-compatible stores; confirm intended use
	BucketName                        string  `yaml:"bucket_name"`
	Prefix                            string  `yaml:"prefix"`
	Key                               string  `yaml:"-"` //Only for DSN acquisition
	PollingMethod                     string  `yaml:"polling_method"`   // "list" (default) or "sqs"
	PollingInterval                   int     `yaml:"polling_interval"` // seconds between list polls; defaults to 60
	SQSName                           string  `yaml:"sqs_name"`
	SQSFormat                         string  `yaml:"sqs_format"` // "eventbridge" or "s3notification"; auto-detected when empty
}
|
||||
|
||||
// S3Source reads log lines from S3 objects, discovered either by
// periodically listing a bucket or by consuming notifications from an
// SQS queue.
type S3Source struct {
	Config     S3Configuration
	logger     *log.Entry
	s3Client   s3iface.S3API
	sqsClient  sqsiface.SQSAPI
	readerChan chan S3Object // objects queued for reading by readManager
	t          *tomb.Tomb
	out        chan types.Event // events delivered to the acquisition pipeline
	ctx        aws.Context
	cancel     context.CancelFunc
}
|
||||
|
||||
// S3Object identifies a single object to read.
type S3Object struct {
	Key    string
	Bucket string
}
|
||||
|
||||
// For some reason, the aws sdk doesn't have a struct for this
// The one aws-lamdbda-go/events is only intended when using S3 Notification without event bridge
// S3Event mirrors the EventBridge envelope for S3 "Object Created" events;
// only Detail.Bucket.Name and Detail.Object.Key are consumed here.
type S3Event struct {
	Version    string   `json:"version"`
	Id         string   `json:"id"`
	DetailType string   `json:"detail-type"`
	Source     string   `json:"source"`
	Account    string   `json:"account"`
	Time       string   `json:"time"`
	Region     string   `json:"region"`
	Resources  []string `json:"resources"`
	Detail     struct {
		Version         string `json:"version"`
		RequestId       string `json:"request-id"`
		Requester       string `json:"requester"`
		Reason          string `json:"reason"`
		SourceIpAddress string `json:"source-ip-address"`
		Bucket          struct {
			Name string `json:"name"`
		} `json:"bucket"`
		Object struct {
			Key       string `json:"key"`
			Size      int    `json:"size"`
			Etag      string `json:"etag"`
			Sequencer string `json:"sequencer"`
		} `json:"object"`
	} `json:"detail"`
}
|
||||
|
||||
// Polling strategies and SQS message formats supported by the S3 datasource.
const (
	PollMethodList          = "list"
	PollMethodSQS           = "sqs"
	SQSFormatEventBridge    = "eventbridge"
	SQSFormatS3Notification = "s3notification"
)
|
||||
|
||||
var linesRead = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cs_s3_hits_total",
|
||||
Help: "Number of events read per bucket.",
|
||||
},
|
||||
[]string{"bucket"},
|
||||
)
|
||||
|
||||
var objectsRead = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cs_s3_objects_total",
|
||||
Help: "Number of objects read per bucket.",
|
||||
},
|
||||
[]string{"bucket"},
|
||||
)
|
||||
|
||||
var sqsMessagesReceived = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cs_s3_sqs_messages_total",
|
||||
Help: "Number of SQS messages received per queue.",
|
||||
},
|
||||
[]string{"queue"},
|
||||
)
|
||||
|
||||
func (s *S3Source) newS3Client() error {
|
||||
options := session.Options{
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}
|
||||
if s.Config.AwsProfile != nil {
|
||||
options.Profile = *s.Config.AwsProfile
|
||||
}
|
||||
|
||||
sess, err := session.NewSessionWithOptions(options)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create aws session: %w", err)
|
||||
}
|
||||
|
||||
config := aws.NewConfig()
|
||||
if s.Config.AwsRegion != "" {
|
||||
config = config.WithRegion(s.Config.AwsRegion)
|
||||
}
|
||||
if s.Config.AwsEndpoint != "" {
|
||||
config = config.WithEndpoint(s.Config.AwsEndpoint)
|
||||
}
|
||||
s.s3Client = s3.New(sess, config)
|
||||
if s.s3Client == nil {
|
||||
return fmt.Errorf("failed to create S3 client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Source) newSQSClient() error {
|
||||
var sess *session.Session
|
||||
|
||||
if s.Config.AwsProfile != nil {
|
||||
sess = session.Must(session.NewSessionWithOptions(session.Options{
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
Profile: *s.Config.AwsProfile,
|
||||
}))
|
||||
} else {
|
||||
sess = session.Must(session.NewSessionWithOptions(session.Options{
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}))
|
||||
}
|
||||
|
||||
if sess == nil {
|
||||
return fmt.Errorf("failed to create aws session")
|
||||
}
|
||||
config := aws.NewConfig()
|
||||
if s.Config.AwsRegion != "" {
|
||||
config = config.WithRegion(s.Config.AwsRegion)
|
||||
}
|
||||
if s.Config.AwsEndpoint != "" {
|
||||
config = config.WithEndpoint(s.Config.AwsEndpoint)
|
||||
}
|
||||
s.sqsClient = sqs.New(sess, config)
|
||||
if s.sqsClient == nil {
|
||||
return fmt.Errorf("failed to create SQS client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Source) readManager() {
|
||||
logger := s.logger.WithField("method", "readManager")
|
||||
for {
|
||||
select {
|
||||
case <-s.t.Dying():
|
||||
logger.Infof("Shutting down S3 read manager")
|
||||
s.cancel()
|
||||
return
|
||||
case s3Object := <-s.readerChan:
|
||||
logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key)
|
||||
err := s.readFile(s3Object.Bucket, s3Object.Key)
|
||||
if err != nil {
|
||||
logger.Errorf("Error while reading file: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getBucketContent lists every object under the configured bucket/prefix,
// following pagination, and returns them sorted by ascending LastModified.
func (s *S3Source) getBucketContent() ([]*s3.Object, error) {
	logger := s.logger.WithField("method", "getBucketContent")
	logger.Debugf("Getting bucket content for %s", s.Config.BucketName)
	bucketObjects := make([]*s3.Object, 0)
	var continuationToken *string = nil
	for {
		out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{
			Bucket:            aws.String(s.Config.BucketName),
			Prefix:            aws.String(s.Config.Prefix),
			ContinuationToken: continuationToken,
		})
		if err != nil {
			logger.Errorf("Error while listing bucket content: %s", err)
			return nil, err
		}
		bucketObjects = append(bucketObjects, out.Contents...)
		// A nil NextContinuationToken means this was the last page.
		if out.NextContinuationToken == nil {
			break
		}
		continuationToken = out.NextContinuationToken
	}
	// Oldest first, so listPoll can scan from the end for new objects.
	sort.Slice(bucketObjects, func(i, j int) bool {
		return bucketObjects[i].LastModified.Before(*bucketObjects[j].LastModified)
	})
	return bucketObjects, nil
}
|
||||
|
||||
// listPoll periodically lists the bucket and pushes every object newer than
// the newest object seen so far onto readerChan. Runs until the tomb dies.
func (s *S3Source) listPoll() error {
	logger := s.logger.WithField("method", "listPoll")
	ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second)
	// Objects last modified before startup are never picked up in this mode.
	lastObjectDate := time.Now()
	defer ticker.Stop()

	for {
		select {
		case <-s.t.Dying():
			logger.Infof("Shutting down list poller")
			s.cancel()
			return nil
		case <-ticker.C:
			newObject := false
			bucketObjects, err := s.getBucketContent()
			if err != nil {
				logger.Errorf("Error while getting bucket content: %s", err)
				continue
			}
			if bucketObjects == nil {
				continue
			}
			// bucketObjects is sorted oldest-first; walk backwards from the
			// newest end and stop at the first object that is not strictly
			// newer than the last one processed.
			for i := len(bucketObjects) - 1; i >= 0; i-- {
				if bucketObjects[i].LastModified.After(lastObjectDate) {
					newObject = true
					logger.Debugf("Found new object %s", *bucketObjects[i].Key)
					s.readerChan <- S3Object{
						Bucket: s.Config.BucketName,
						Key:    *bucketObjects[i].Key,
					}
				} else {
					break
				}
			}
			if newObject {
				// The last element is the newest thanks to the ascending sort.
				lastObjectDate = *bucketObjects[len(bucketObjects)-1].LastModified
			}
		}
	}
}
|
||||
|
||||
func extractBucketAndPrefixFromEventBridge(message *string) (string, string, error) {
|
||||
eventBody := S3Event{}
|
||||
err := json.Unmarshal([]byte(*message), &eventBody)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if eventBody.Detail.Bucket.Name != "" {
|
||||
return eventBody.Detail.Bucket.Name, eventBody.Detail.Object.Key, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid event body for event bridge format")
|
||||
}
|
||||
|
||||
func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) {
|
||||
s3notifBody := events.S3Event{}
|
||||
err := json.Unmarshal([]byte(*message), &s3notifBody)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if len(s3notifBody.Records) == 0 {
|
||||
return "", "", fmt.Errorf("no records found in S3 notification")
|
||||
}
|
||||
if !strings.HasPrefix(s3notifBody.Records[0].EventName, "ObjectCreated:") {
|
||||
return "", "", fmt.Errorf("event %s is not supported", s3notifBody.Records[0].EventName)
|
||||
}
|
||||
return s3notifBody.Records[0].S3.Bucket.Name, s3notifBody.Records[0].S3.Object.Key, nil
|
||||
}
|
||||
|
||||
func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, error) {
|
||||
if s.Config.SQSFormat == SQSFormatEventBridge {
|
||||
bucket, key, err := extractBucketAndPrefixFromEventBridge(message)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return bucket, key, nil
|
||||
} else if s.Config.SQSFormat == SQSFormatS3Notification {
|
||||
bucket, key, err := extractBucketAndPrefixFromS3Notif(message)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return bucket, key, nil
|
||||
} else {
|
||||
bucket, key, err := extractBucketAndPrefixFromEventBridge(message)
|
||||
if err == nil {
|
||||
s.Config.SQSFormat = SQSFormatEventBridge
|
||||
return bucket, key, nil
|
||||
}
|
||||
bucket, key, err = extractBucketAndPrefixFromS3Notif(message)
|
||||
if err == nil {
|
||||
s.Config.SQSFormat = SQSFormatS3Notification
|
||||
return bucket, key, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("SQS message format not supported")
|
||||
}
|
||||
}
|
||||
|
||||
// sqsPoll long-polls the configured SQS queue for S3 notifications, queues
// the referenced objects for reading, and deletes every message afterwards
// (including unparseable ones, so they are never re-delivered forever).
func (s *S3Source) sqsPoll() error {
	logger := s.logger.WithField("method", "sqsPoll")
	for {
		select {
		case <-s.t.Dying():
			logger.Infof("Shutting down SQS poller")
			s.cancel()
			return nil
		default:
			logger.Trace("Polling SQS queue")
			// WaitTimeSeconds=20 makes this a blocking long poll.
			out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{
				QueueUrl:            aws.String(s.Config.SQSName),
				MaxNumberOfMessages: aws.Int64(10),
				WaitTimeSeconds:     aws.Int64(20), //Probably no need to make it configurable ?
			})
			if err != nil {
				// NOTE(review): if ReceiveMessage fails immediately (bad queue
				// URL, credentials), this loop spins without backoff — consider
				// adding a short sleep here.
				logger.Errorf("Error while polling SQS: %s", err)
				continue
			}
			logger.Tracef("SQS output: %v", out)
			logger.Debugf("Received %d messages from SQS", len(out.Messages))
			for _, message := range out.Messages {
				sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc()
				bucket, key, err := s.extractBucketAndPrefix(message.Body)
				if err != nil {
					logger.Errorf("Error while parsing SQS message: %s", err)
					//Always delete the message to avoid infinite loop
					_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
						QueueUrl:      aws.String(s.Config.SQSName),
						ReceiptHandle: message.ReceiptHandle,
					})
					if err != nil {
						logger.Errorf("Error while deleting SQS message: %s", err)
					}
					continue
				}
				logger.Debugf("Received SQS message for object %s/%s", bucket, key)
				s.readerChan <- S3Object{Key: key, Bucket: bucket}
				_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String(s.Config.SQSName),
					ReceiptHandle: message.ReceiptHandle,
				})
				if err != nil {
					logger.Errorf("Error while deleting SQS message: %s", err)
				}
				logger.Debugf("Deleted SQS message for object %s/%s", bucket, key)
			}
		}
	}
}
|
||||
|
||||
func (s *S3Source) readFile(bucket string, key string) error {
|
||||
//TODO: Handle SSE-C
|
||||
var scanner *bufio.Scanner
|
||||
|
||||
logger := s.logger.WithFields(logrus.Fields{
|
||||
"method": "readFile",
|
||||
"bucket": bucket,
|
||||
"key": key,
|
||||
})
|
||||
|
||||
output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err)
|
||||
}
|
||||
defer output.Body.Close()
|
||||
if strings.HasSuffix(key, ".gz") {
|
||||
gzReader, err := gzip.NewReader(output.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read gzip object %s/%s: %w", bucket, key, err)
|
||||
}
|
||||
defer gzReader.Close()
|
||||
scanner = bufio.NewScanner(gzReader)
|
||||
} else {
|
||||
scanner = bufio.NewScanner(output.Body)
|
||||
}
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
logger.Tracef("Read line %s", text)
|
||||
linesRead.WithLabelValues(bucket).Inc()
|
||||
l := types.Line{}
|
||||
l.Raw = text
|
||||
l.Labels = s.Config.Labels
|
||||
l.Time = time.Now().UTC()
|
||||
l.Process = true
|
||||
l.Module = s.GetName()
|
||||
l.Src = bucket
|
||||
var evt types.Event
|
||||
if !s.Config.UseTimeMachine {
|
||||
evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
|
||||
} else {
|
||||
evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
|
||||
}
|
||||
s.out <- evt
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("failed to read object %s/%s: %s", bucket, key, err)
|
||||
}
|
||||
objectsRead.WithLabelValues(bucket).Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetrics returns the prometheus collectors exposed by this datasource.
func (s *S3Source) GetMetrics() []prometheus.Collector {
	return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}
|
||||
// GetAggregMetrics returns the same collectors as GetMetrics; the counters
// are already aggregated per bucket/queue label.
func (s *S3Source) GetAggregMetrics() []prometheus.Collector {
	return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}
|
||||
|
||||
// UnmarshalConfig parses the YAML configuration, applies defaults
// (mode=tail, polling_method=list, polling_interval=60s) and validates the
// polling method / bucket / SQS combination.
func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error {
	s.Config = S3Configuration{}
	err := yaml.UnmarshalStrict(yamlConfig, &s.Config)
	if err != nil {
		return fmt.Errorf("cannot parse S3Acquisition configuration: %w", err)
	}
	if s.Config.Mode == "" {
		s.Config.Mode = configuration.TAIL_MODE
	}
	if s.Config.PollingMethod == "" {
		s.Config.PollingMethod = PollMethodList
	}

	if s.Config.PollingInterval == 0 {
		s.Config.PollingInterval = 60
	}

	if s.Config.PollingMethod != PollMethodList && s.Config.PollingMethod != PollMethodSQS {
		return fmt.Errorf("invalid polling method %s", s.Config.PollingMethod)
	}

	// With SQS polling the bucket comes from the notification itself.
	if s.Config.BucketName != "" && s.Config.SQSName != "" {
		return fmt.Errorf("bucket_name and sqs_name are mutually exclusive")
	}

	if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" {
		return fmt.Errorf("sqs_name is required when using sqs polling method")
	}

	if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList {
		return fmt.Errorf("bucket_name is required")
	}

	// Empty sqs_format is allowed: it is auto-detected from the first message.
	if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification {
		return fmt.Errorf("invalid sqs_format %s, must be empty, %s or %s", s.Config.SQSFormat, SQSFormatEventBridge, SQSFormatS3Notification)
	}

	return nil
}
|
||||
|
||||
// Configure parses and validates the YAML configuration, sets up the scoped
// logger and creates the S3 (and, for SQS polling, SQS) clients.
func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry) error {
	err := s.UnmarshalConfig(yamlConfig)
	if err != nil {
		return err
	}

	// Scope log fields to whichever discovery mechanism is in use.
	if s.Config.SQSName != "" {
		s.logger = logger.WithFields(log.Fields{
			"queue": s.Config.SQSName,
		})
	} else {
		s.logger = logger.WithFields(log.Fields{
			"bucket": s.Config.BucketName,
			"prefix": s.Config.Prefix,
		})
	}

	if !s.Config.UseTimeMachine {
		s.logger.Warning("use_time_machine is not set to true in the datasource configuration. This will likely lead to false positives as S3 logs are not processed in real time.")
	}

	if s.Config.PollingMethod == PollMethodList {
		s.logger.Warning("Polling method is set to list. This is not recommended as it will not scale well. Consider using SQS instead.")
	}

	err = s.newS3Client()
	if err != nil {
		return err
	}

	// The SQS client is only needed for the sqs polling method.
	if s.Config.PollingMethod == PollMethodSQS {
		err = s.newSQSClient()
		if err != nil {
			return err
		}
	}

	return nil
}
|
||||
|
||||
func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry) error {
|
||||
if !strings.HasPrefix(dsn, "s3://") {
|
||||
return fmt.Errorf("invalid DSN %s for S3 source, must start with s3://", dsn)
|
||||
}
|
||||
|
||||
s.logger = logger.WithFields(log.Fields{
|
||||
"bucket": s.Config.BucketName,
|
||||
"prefix": s.Config.Prefix,
|
||||
})
|
||||
|
||||
dsn = strings.TrimPrefix(dsn, "s3://")
|
||||
args := strings.Split(dsn, "?")
|
||||
if len(args[0]) == 0 {
|
||||
return fmt.Errorf("empty s3:// DSN")
|
||||
}
|
||||
|
||||
if len(args) == 2 && len(args[1]) != 0 {
|
||||
params, err := url.ParseQuery(args[1])
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not parse s3 args")
|
||||
}
|
||||
for key, value := range params {
|
||||
if key != "log_level" {
|
||||
return fmt.Errorf("unsupported key %s in s3 DSN", key)
|
||||
}
|
||||
if len(value) != 1 {
|
||||
return errors.New("expected zero or one value for 'log_level'")
|
||||
}
|
||||
lvl, err := log.ParseLevel(value[0])
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unknown level %s", value[0])
|
||||
}
|
||||
s.logger.Logger.SetLevel(lvl)
|
||||
}
|
||||
}
|
||||
|
||||
s.Config = S3Configuration{}
|
||||
s.Config.Labels = labels
|
||||
s.Config.Mode = configuration.CAT_MODE
|
||||
|
||||
pathParts := strings.Split(args[0], "/")
|
||||
s.logger.Debugf("pathParts: %v", pathParts)
|
||||
|
||||
//FIXME: handle s3://bucket/
|
||||
if len(pathParts) == 1 {
|
||||
s.Config.BucketName = pathParts[0]
|
||||
s.Config.Prefix = ""
|
||||
} else if len(pathParts) > 1 {
|
||||
s.Config.BucketName = pathParts[0]
|
||||
if args[0][len(args[0])-1] == '/' {
|
||||
s.Config.Prefix = strings.Join(pathParts[1:], "/")
|
||||
} else {
|
||||
s.Config.Key = strings.Join(pathParts[1:], "/")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("invalid DSN %s for S3 source", dsn)
|
||||
}
|
||||
|
||||
err := s.newS3Client()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMode returns the configured acquisition mode (tail or cat).
func (s *S3Source) GetMode() string {
	return s.Config.Mode
}
|
||||
|
||||
// GetName returns the datasource type identifier.
func (s *S3Source) GetName() string {
	return "s3"
}
|
||||
|
||||
// OneShotAcquisition reads either a single object (when a key was set via
// DSN) or every object under the configured prefix, then returns.
func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error {
	s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key)
	s.out = out
	s.ctx, s.cancel = context.WithCancel(context.Background())
	if s.Config.Key != "" {
		err := s.readFile(s.Config.BucketName, s.Config.Key)
		if err != nil {
			return err
		}
	} else {
		//No key, get everything in the bucket based on the prefix
		objects, err := s.getBucketContent()
		if err != nil {
			return err
		}
		for _, object := range objects {
			err := s.readFile(s.Config.BucketName, *object.Key)
			if err != nil {
				return err
			}
		}
	}
	return nil
}
|
||||
|
||||
func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
|
||||
s.t = t
|
||||
s.out = out
|
||||
s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered?
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix)
|
||||
t.Go(func() error {
|
||||
s.readManager()
|
||||
return nil
|
||||
})
|
||||
if s.Config.PollingMethod == PollMethodSQS {
|
||||
t.Go(func() error {
|
||||
err := s.sqsPoll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
} else {
|
||||
t.Go(func() error {
|
||||
err := s.listPoll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanRun always succeeds: the S3 datasource has no platform restrictions.
func (s *S3Source) CanRun() error {
	return nil
}
|
||||
|
||||
// Dump exposes the datasource state for debugging.
func (s *S3Source) Dump() interface{} {
	return s
}
|
429
pkg/acquisition/modules/s3/s3_test.go
Normal file
429
pkg/acquisition/modules/s3/s3_test.go
Normal file
|
@ -0,0 +1,429 @@
|
|||
package s3acquisition
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3iface"
|
||||
"github.com/aws/aws-sdk-go/service/sqs"
|
||||
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
|
||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/tomb.v2"
|
||||
)
|
||||
|
||||
// TestBadConfiguration checks that invalid YAML configurations are rejected
// with the exact expected validation error.
func TestBadConfiguration(t *testing.T) {
	tests := []struct {
		name        string
		config      string
		expectedErr string
	}{
		{
			name: "no bucket",
			config: `
source: s3
`,
			expectedErr: "bucket_name is required",
		},
		{
			name: "invalid polling method",
			config: `
source: s3
bucket_name: foobar
polling_method: foobar
`,
			expectedErr: "invalid polling method foobar",
		},
		{
			name: "no sqs name",
			config: `
source: s3
bucket_name: foobar
polling_method: sqs
`,
			expectedErr: "sqs_name is required when using sqs polling method",
		},
		{
			name: "both bucket and sqs",
			config: `
source: s3
bucket_name: foobar
polling_method: sqs
sqs_name: foobar
`,
			expectedErr: "bucket_name and sqs_name are mutually exclusive",
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			f := S3Source{}
			err := f.Configure([]byte(test.config), nil)
			if err == nil {
				t.Fatalf("expected error, got none")
			}
			if err.Error() != test.expectedErr {
				t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error())
			}
		})
	}
}
|
||||
|
||||
// TestGoodConfiguration checks that valid YAML configurations are accepted.
func TestGoodConfiguration(t *testing.T) {
	tests := []struct {
		name   string
		config string
	}{
		{
			name: "basic",
			config: `
source: s3
bucket_name: foobar
`,
		},
		{
			name: "polling method",
			config: `
source: s3
polling_method: sqs
sqs_name: foobar
`,
		},
		{
			name: "list method",
			config: `
source: s3
bucket_name: foobar
polling_method: list
`,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			f := S3Source{}
			logger := log.NewEntry(log.New())
			err := f.Configure([]byte(test.config), logger)
			if err != nil {
				t.Fatalf("unexpected error: %s", err.Error())
			}
		})
	}
}
|
||||
|
||||
// mockS3Client stubs the S3 API; only List/GetObject are overridden below.
type mockS3Client struct {
	s3iface.S3API
}
|
||||
|
||||
// We add one hour to trick the listing goroutine into thinking the files are new
// mockListOutput maps a bucket name to the objects its listing returns.
var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{
	"bucket_no_prefix": {
		{
			Key:          aws.String("foo.log"),
			LastModified: aws.Time(time.Now().Add(time.Hour)),
		},
	},
	"bucket_with_prefix": {
		{
			Key:          aws.String("prefix/foo.log"),
			LastModified: aws.Time(time.Now().Add(time.Hour)),
		},
		{
			Key:          aws.String("prefix/bar.log"),
			LastModified: aws.Time(time.Now().Add(time.Hour)),
		},
	},
}
|
||||
|
||||
// ListObjectsV2WithContext returns the canned listing for the requested
// bucket (single page, no pagination).
func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) {
	log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket])
	return &s3.ListObjectsV2Output{
		Contents: mockListOutput[*input.Bucket],
	}, nil
}
|
||||
|
||||
// GetObjectWithContext returns a fixed two-line body for any object, so each
// mock object yields exactly two events.
func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) {
	r := strings.NewReader("foo\nbar")
	return &s3.GetObjectOutput{
		Body: aws.ReadSeekCloser(r),
	}, nil
}
|
||||
|
||||
// mockSQSClient stubs the SQS API and delivers a single EventBridge-style
// message; counter tracks whether it has been delivered yet.
type mockSQSClient struct {
	sqsiface.SQSAPI
	counter *int32
}
|
||||
|
||||
// ReceiveMessageWithContext returns one EventBridge-formatted message on the
// first call, then empty responses forever after.
func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
	if atomic.LoadInt32(msqs.counter) == 1 {
		return &sqs.ReceiveMessageOutput{}, nil
	}
	atomic.AddInt32(msqs.counter, 1)
	return &sqs.ReceiveMessageOutput{
		Messages: []*sqs.Message{
			{
				Body: aws.String(`
{"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`),
			},
		},
	}, nil
}
|
||||
|
||||
// DeleteMessage is a no-op for the mock queue.
func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
	return &sqs.DeleteMessageOutput{}, nil
}
|
||||
|
||||
// mockSQSClientNotif stubs the SQS API and delivers a single classic
// S3-notification-style message; counter tracks delivery state.
type mockSQSClientNotif struct {
	sqsiface.SQSAPI
	counter *int32
}
|
||||
|
||||
// ReceiveMessageWithContext returns one S3-notification-formatted message on
// the first call, then empty responses forever after.
func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
	if atomic.LoadInt32(msqs.counter) == 1 {
		return &sqs.ReceiveMessageOutput{}, nil
	}
	atomic.AddInt32(msqs.counter, 1)
	return &sqs.ReceiveMessageOutput{
		Messages: []*sqs.Message{
			{
				Body: aws.String(`
{"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`),
			},
		},
	}, nil
}
|
||||
|
||||
// DeleteMessage is a no-op for the mock queue.
func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
	return &sqs.DeleteMessageOutput{}, nil
}
|
||||
|
||||
func TestDSNAcquis(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dsn string
|
||||
expectedBucketName string
|
||||
expectedPrefix string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
dsn: "s3://bucket_no_prefix/foo.log",
|
||||
expectedBucketName: "bucket_no_prefix",
|
||||
expectedPrefix: "",
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "with prefix",
|
||||
dsn: "s3://bucket_with_prefix/prefix/",
|
||||
expectedBucketName: "bucket_with_prefix",
|
||||
expectedPrefix: "prefix/",
|
||||
expectedCount: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
linesRead := 0
|
||||
f := S3Source{}
|
||||
logger := log.NewEntry(log.New())
|
||||
err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, test.expectedBucketName, f.Config.BucketName)
|
||||
assert.Equal(t, test.expectedPrefix, f.Config.Prefix)
|
||||
out := make(chan types.Event)
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case s := <-out:
|
||||
fmt.Printf("got line %s\n", s.Line.Raw)
|
||||
linesRead++
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
f.s3Client = mockS3Client{}
|
||||
err = f.OneShotAcquisition(out, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
done <- true
|
||||
assert.Equal(t, test.expectedCount, linesRead)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestListPolling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
config: `
|
||||
source: s3
|
||||
bucket_name: bucket_no_prefix
|
||||
polling_method: list
|
||||
polling_interval: 1
|
||||
`,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "with prefix",
|
||||
config: `
|
||||
source: s3
|
||||
bucket_name: bucket_with_prefix
|
||||
polling_method: list
|
||||
polling_interval: 1
|
||||
prefix: foo/
|
||||
`,
|
||||
expectedCount: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
linesRead := 0
|
||||
f := S3Source{}
|
||||
logger := log.NewEntry(log.New())
|
||||
logger.Logger.SetLevel(log.TraceLevel)
|
||||
err := f.Configure([]byte(test.config), logger)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
if f.Config.PollingMethod != PollMethodList {
|
||||
t.Fatalf("expected list polling, got %s", f.Config.PollingMethod)
|
||||
}
|
||||
|
||||
f.s3Client = mockS3Client{}
|
||||
|
||||
out := make(chan types.Event)
|
||||
tb := tomb.Tomb{}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case s := <-out:
|
||||
fmt.Printf("got line %s\n", s.Line.Raw)
|
||||
linesRead++
|
||||
case <-tb.Dying():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f.StreamingAcquisition(out, &tb)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
tb.Kill(nil)
|
||||
err = tb.Wait()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, test.expectedCount, linesRead)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQSPoll(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
notifType string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "eventbridge",
|
||||
config: `
|
||||
source: s3
|
||||
polling_method: sqs
|
||||
sqs_name: test
|
||||
`,
|
||||
expectedCount: 2,
|
||||
notifType: "eventbridge",
|
||||
},
|
||||
{
|
||||
name: "notification",
|
||||
config: `
|
||||
source: s3
|
||||
polling_method: sqs
|
||||
sqs_name: test
|
||||
`,
|
||||
expectedCount: 2,
|
||||
notifType: "notification",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
linesRead := 0
|
||||
f := S3Source{}
|
||||
logger := log.NewEntry(log.New())
|
||||
err := f.Configure([]byte(test.config), logger)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
if f.Config.PollingMethod != PollMethodSQS {
|
||||
t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod)
|
||||
}
|
||||
|
||||
counter := int32(0)
|
||||
f.s3Client = mockS3Client{}
|
||||
if test.notifType == "eventbridge" {
|
||||
f.sqsClient = mockSQSClient{counter: &counter}
|
||||
} else {
|
||||
f.sqsClient = mockSQSClientNotif{counter: &counter}
|
||||
}
|
||||
|
||||
out := make(chan types.Event)
|
||||
tb := tomb.Tomb{}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case s := <-out:
|
||||
fmt.Printf("got line %s\n", s.Line.Raw)
|
||||
linesRead++
|
||||
case <-tb.Dying():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = f.StreamingAcquisition(out, &tb)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
tb.Kill(nil)
|
||||
err = tb.Wait()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err.Error())
|
||||
}
|
||||
assert.Equal(t, test.expectedCount, linesRead)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue