S3 acquisition datasource (#2130)

This commit is contained in:
blotus 2023-03-21 13:54:52 +01:00 committed by GitHub
parent a74e424d53
commit dc38e5ac00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 1081 additions and 0 deletions

1
go.mod
View file

@ -98,6 +98,7 @@ require (
github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect
github.com/aws/aws-lambda-go v1.38.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/containerd/containerd v1.6.18 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect

2
go.sum
View file

@ -107,6 +107,8 @@ github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:o
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg=
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/aws/aws-lambda-go v1.38.0 h1:4CUdxGzvuQp0o8Zh7KtupB9XvCiiY8yKqJtzco+gsDw=
github.com/aws/aws-lambda-go v1.38.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM=
github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48=
github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs=
github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs=

View file

@ -20,6 +20,7 @@ import (
kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka"
kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis"
k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit"
s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3"
syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog"
wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog"
@ -52,6 +53,7 @@ var AcquisitionSources = map[string]func() DataSource{
"wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} },
"kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} },
"k8s_audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} },
"s3": func() DataSource { return &s3acquisition.S3Source{} },
}
func GetDataSourceIface(dataSourceType string) DataSource {

View file

@ -0,0 +1,647 @@
package s3acquisition
import (
"bufio"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"net/url"
"sort"
"strings"
"time"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"gopkg.in/tomb.v2"
"gopkg.in/yaml.v2"
)
type S3Configuration struct {
configuration.DataSourceCommonCfg `yaml:",inline"`
AwsProfile *string `yaml:"aws_profile"`
AwsRegion string `yaml:"aws_region"`
AwsEndpoint string `yaml:"aws_endpoint"`
BucketName string `yaml:"bucket_name"`
Prefix string `yaml:"prefix"`
Key string `yaml:"-"` //Only for DSN acquisition
PollingMethod string `yaml:"polling_method"`
PollingInterval int `yaml:"polling_interval"`
SQSName string `yaml:"sqs_name"`
SQSFormat string `yaml:"sqs_format"`
}
type S3Source struct {
Config S3Configuration
logger *log.Entry
s3Client s3iface.S3API
sqsClient sqsiface.SQSAPI
readerChan chan S3Object
t *tomb.Tomb
out chan types.Event
ctx aws.Context
cancel context.CancelFunc
}
type S3Object struct {
Key string
Bucket string
}
// For some reason, the aws sdk doesn't have a struct for this
// The one aws-lamdbda-go/events is only intended when using S3 Notification without event bridge
type S3Event struct {
Version string `json:"version"`
Id string `json:"id"`
DetailType string `json:"detail-type"`
Source string `json:"source"`
Account string `json:"account"`
Time string `json:"time"`
Region string `json:"region"`
Resources []string `json:"resources"`
Detail struct {
Version string `json:"version"`
RequestId string `json:"request-id"`
Requester string `json:"requester"`
Reason string `json:"reason"`
SourceIpAddress string `json:"source-ip-address"`
Bucket struct {
Name string `json:"name"`
} `json:"bucket"`
Object struct {
Key string `json:"key"`
Size int `json:"size"`
Etag string `json:"etag"`
Sequencer string `json:"sequencer"`
} `json:"object"`
} `json:"detail"`
}
const PollMethodList = "list"
const PollMethodSQS = "sqs"
const SQSFormatEventBridge = "eventbridge"
const SQSFormatS3Notification = "s3notification"
var linesRead = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cs_s3_hits_total",
Help: "Number of events read per bucket.",
},
[]string{"bucket"},
)
var objectsRead = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cs_s3_objects_total",
Help: "Number of objects read per bucket.",
},
[]string{"bucket"},
)
var sqsMessagesReceived = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cs_s3_sqs_messages_total",
Help: "Number of SQS messages received per queue.",
},
[]string{"queue"},
)
func (s *S3Source) newS3Client() error {
options := session.Options{
SharedConfigState: session.SharedConfigEnable,
}
if s.Config.AwsProfile != nil {
options.Profile = *s.Config.AwsProfile
}
sess, err := session.NewSessionWithOptions(options)
if err != nil {
return fmt.Errorf("failed to create aws session: %w", err)
}
config := aws.NewConfig()
if s.Config.AwsRegion != "" {
config = config.WithRegion(s.Config.AwsRegion)
}
if s.Config.AwsEndpoint != "" {
config = config.WithEndpoint(s.Config.AwsEndpoint)
}
s.s3Client = s3.New(sess, config)
if s.s3Client == nil {
return fmt.Errorf("failed to create S3 client")
}
return nil
}
func (s *S3Source) newSQSClient() error {
var sess *session.Session
if s.Config.AwsProfile != nil {
sess = session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
Profile: *s.Config.AwsProfile,
}))
} else {
sess = session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
}
if sess == nil {
return fmt.Errorf("failed to create aws session")
}
config := aws.NewConfig()
if s.Config.AwsRegion != "" {
config = config.WithRegion(s.Config.AwsRegion)
}
if s.Config.AwsEndpoint != "" {
config = config.WithEndpoint(s.Config.AwsEndpoint)
}
s.sqsClient = sqs.New(sess, config)
if s.sqsClient == nil {
return fmt.Errorf("failed to create SQS client")
}
return nil
}
func (s *S3Source) readManager() {
logger := s.logger.WithField("method", "readManager")
for {
select {
case <-s.t.Dying():
logger.Infof("Shutting down S3 read manager")
s.cancel()
return
case s3Object := <-s.readerChan:
logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key)
err := s.readFile(s3Object.Bucket, s3Object.Key)
if err != nil {
logger.Errorf("Error while reading file: %s", err)
}
}
}
}
func (s *S3Source) getBucketContent() ([]*s3.Object, error) {
logger := s.logger.WithField("method", "getBucketContent")
logger.Debugf("Getting bucket content for %s", s.Config.BucketName)
bucketObjects := make([]*s3.Object, 0)
var continuationToken *string = nil
for {
out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(s.Config.BucketName),
Prefix: aws.String(s.Config.Prefix),
ContinuationToken: continuationToken,
})
if err != nil {
logger.Errorf("Error while listing bucket content: %s", err)
return nil, err
}
bucketObjects = append(bucketObjects, out.Contents...)
if out.NextContinuationToken == nil {
break
}
continuationToken = out.NextContinuationToken
}
sort.Slice(bucketObjects, func(i, j int) bool {
return bucketObjects[i].LastModified.Before(*bucketObjects[j].LastModified)
})
return bucketObjects, nil
}
func (s *S3Source) listPoll() error {
logger := s.logger.WithField("method", "listPoll")
ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second)
lastObjectDate := time.Now()
defer ticker.Stop()
for {
select {
case <-s.t.Dying():
logger.Infof("Shutting down list poller")
s.cancel()
return nil
case <-ticker.C:
newObject := false
bucketObjects, err := s.getBucketContent()
if err != nil {
logger.Errorf("Error while getting bucket content: %s", err)
continue
}
if bucketObjects == nil {
continue
}
for i := len(bucketObjects) - 1; i >= 0; i-- {
if bucketObjects[i].LastModified.After(lastObjectDate) {
newObject = true
logger.Debugf("Found new object %s", *bucketObjects[i].Key)
s.readerChan <- S3Object{
Bucket: s.Config.BucketName,
Key: *bucketObjects[i].Key,
}
} else {
break
}
}
if newObject {
lastObjectDate = *bucketObjects[len(bucketObjects)-1].LastModified
}
}
}
}
func extractBucketAndPrefixFromEventBridge(message *string) (string, string, error) {
eventBody := S3Event{}
err := json.Unmarshal([]byte(*message), &eventBody)
if err != nil {
return "", "", err
}
if eventBody.Detail.Bucket.Name != "" {
return eventBody.Detail.Bucket.Name, eventBody.Detail.Object.Key, nil
}
return "", "", fmt.Errorf("invalid event body for event bridge format")
}
func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) {
s3notifBody := events.S3Event{}
err := json.Unmarshal([]byte(*message), &s3notifBody)
if err != nil {
return "", "", err
}
if len(s3notifBody.Records) == 0 {
return "", "", fmt.Errorf("no records found in S3 notification")
}
if !strings.HasPrefix(s3notifBody.Records[0].EventName, "ObjectCreated:") {
return "", "", fmt.Errorf("event %s is not supported", s3notifBody.Records[0].EventName)
}
return s3notifBody.Records[0].S3.Bucket.Name, s3notifBody.Records[0].S3.Object.Key, nil
}
func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, error) {
if s.Config.SQSFormat == SQSFormatEventBridge {
bucket, key, err := extractBucketAndPrefixFromEventBridge(message)
if err != nil {
return "", "", err
}
return bucket, key, nil
} else if s.Config.SQSFormat == SQSFormatS3Notification {
bucket, key, err := extractBucketAndPrefixFromS3Notif(message)
if err != nil {
return "", "", err
}
return bucket, key, nil
} else {
bucket, key, err := extractBucketAndPrefixFromEventBridge(message)
if err == nil {
s.Config.SQSFormat = SQSFormatEventBridge
return bucket, key, nil
}
bucket, key, err = extractBucketAndPrefixFromS3Notif(message)
if err == nil {
s.Config.SQSFormat = SQSFormatS3Notification
return bucket, key, nil
}
return "", "", fmt.Errorf("SQS message format not supported")
}
}
func (s *S3Source) sqsPoll() error {
logger := s.logger.WithField("method", "sqsPoll")
for {
select {
case <-s.t.Dying():
logger.Infof("Shutting down SQS poller")
s.cancel()
return nil
default:
logger.Trace("Polling SQS queue")
out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{
QueueUrl: aws.String(s.Config.SQSName),
MaxNumberOfMessages: aws.Int64(10),
WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ?
})
if err != nil {
logger.Errorf("Error while polling SQS: %s", err)
continue
}
logger.Tracef("SQS output: %v", out)
logger.Debugf("Received %d messages from SQS", len(out.Messages))
for _, message := range out.Messages {
sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc()
bucket, key, err := s.extractBucketAndPrefix(message.Body)
if err != nil {
logger.Errorf("Error while parsing SQS message: %s", err)
//Always delete the message to avoid infinite loop
_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
QueueUrl: aws.String(s.Config.SQSName),
ReceiptHandle: message.ReceiptHandle,
})
if err != nil {
logger.Errorf("Error while deleting SQS message: %s", err)
}
continue
}
logger.Debugf("Received SQS message for object %s/%s", bucket, key)
s.readerChan <- S3Object{Key: key, Bucket: bucket}
_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
QueueUrl: aws.String(s.Config.SQSName),
ReceiptHandle: message.ReceiptHandle,
})
if err != nil {
logger.Errorf("Error while deleting SQS message: %s", err)
}
logger.Debugf("Deleted SQS message for object %s/%s", bucket, key)
}
}
}
}
func (s *S3Source) readFile(bucket string, key string) error {
//TODO: Handle SSE-C
var scanner *bufio.Scanner
logger := s.logger.WithFields(logrus.Fields{
"method": "readFile",
"bucket": bucket,
"key": key,
})
output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err)
}
defer output.Body.Close()
if strings.HasSuffix(key, ".gz") {
gzReader, err := gzip.NewReader(output.Body)
if err != nil {
return fmt.Errorf("failed to read gzip object %s/%s: %w", bucket, key, err)
}
defer gzReader.Close()
scanner = bufio.NewScanner(gzReader)
} else {
scanner = bufio.NewScanner(output.Body)
}
for scanner.Scan() {
text := scanner.Text()
logger.Tracef("Read line %s", text)
linesRead.WithLabelValues(bucket).Inc()
l := types.Line{}
l.Raw = text
l.Labels = s.Config.Labels
l.Time = time.Now().UTC()
l.Process = true
l.Module = s.GetName()
l.Src = bucket
var evt types.Event
if !s.Config.UseTimeMachine {
evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
} else {
evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
}
s.out <- evt
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("failed to read object %s/%s: %s", bucket, key, err)
}
objectsRead.WithLabelValues(bucket).Inc()
return nil
}
func (s *S3Source) GetMetrics() []prometheus.Collector {
return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}
func (s *S3Source) GetAggregMetrics() []prometheus.Collector {
return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}
func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error {
s.Config = S3Configuration{}
err := yaml.UnmarshalStrict(yamlConfig, &s.Config)
if err != nil {
return fmt.Errorf("cannot parse S3Acquisition configuration: %w", err)
}
if s.Config.Mode == "" {
s.Config.Mode = configuration.TAIL_MODE
}
if s.Config.PollingMethod == "" {
s.Config.PollingMethod = PollMethodList
}
if s.Config.PollingInterval == 0 {
s.Config.PollingInterval = 60
}
if s.Config.PollingMethod != PollMethodList && s.Config.PollingMethod != PollMethodSQS {
return fmt.Errorf("invalid polling method %s", s.Config.PollingMethod)
}
if s.Config.BucketName != "" && s.Config.SQSName != "" {
return fmt.Errorf("bucket_name and sqs_name are mutually exclusive")
}
if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" {
return fmt.Errorf("sqs_name is required when using sqs polling method")
}
if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList {
return fmt.Errorf("bucket_name is required")
}
if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification {
return fmt.Errorf("invalid sqs_format %s, must be empty, %s or %s", s.Config.SQSFormat, SQSFormatEventBridge, SQSFormatS3Notification)
}
return nil
}
func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry) error {
err := s.UnmarshalConfig(yamlConfig)
if err != nil {
return err
}
if s.Config.SQSName != "" {
s.logger = logger.WithFields(log.Fields{
"queue": s.Config.SQSName,
})
} else {
s.logger = logger.WithFields(log.Fields{
"bucket": s.Config.BucketName,
"prefix": s.Config.Prefix,
})
}
if !s.Config.UseTimeMachine {
s.logger.Warning("use_time_machine is not set to true in the datasource configuration. This will likely lead to false positives as S3 logs are not processed in real time.")
}
if s.Config.PollingMethod == PollMethodList {
s.logger.Warning("Polling method is set to list. This is not recommended as it will not scale well. Consider using SQS instead.")
}
err = s.newS3Client()
if err != nil {
return err
}
if s.Config.PollingMethod == PollMethodSQS {
err = s.newSQSClient()
if err != nil {
return err
}
}
return nil
}
func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry) error {
if !strings.HasPrefix(dsn, "s3://") {
return fmt.Errorf("invalid DSN %s for S3 source, must start with s3://", dsn)
}
s.logger = logger.WithFields(log.Fields{
"bucket": s.Config.BucketName,
"prefix": s.Config.Prefix,
})
dsn = strings.TrimPrefix(dsn, "s3://")
args := strings.Split(dsn, "?")
if len(args[0]) == 0 {
return fmt.Errorf("empty s3:// DSN")
}
if len(args) == 2 && len(args[1]) != 0 {
params, err := url.ParseQuery(args[1])
if err != nil {
return errors.Wrap(err, "could not parse s3 args")
}
for key, value := range params {
if key != "log_level" {
return fmt.Errorf("unsupported key %s in s3 DSN", key)
}
if len(value) != 1 {
return errors.New("expected zero or one value for 'log_level'")
}
lvl, err := log.ParseLevel(value[0])
if err != nil {
return errors.Wrapf(err, "unknown level %s", value[0])
}
s.logger.Logger.SetLevel(lvl)
}
}
s.Config = S3Configuration{}
s.Config.Labels = labels
s.Config.Mode = configuration.CAT_MODE
pathParts := strings.Split(args[0], "/")
s.logger.Debugf("pathParts: %v", pathParts)
//FIXME: handle s3://bucket/
if len(pathParts) == 1 {
s.Config.BucketName = pathParts[0]
s.Config.Prefix = ""
} else if len(pathParts) > 1 {
s.Config.BucketName = pathParts[0]
if args[0][len(args[0])-1] == '/' {
s.Config.Prefix = strings.Join(pathParts[1:], "/")
} else {
s.Config.Key = strings.Join(pathParts[1:], "/")
}
} else {
return fmt.Errorf("invalid DSN %s for S3 source", dsn)
}
err := s.newS3Client()
if err != nil {
return err
}
return nil
}
func (s *S3Source) GetMode() string {
return s.Config.Mode
}
func (s *S3Source) GetName() string {
return "s3"
}
func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error {
s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key)
s.out = out
s.ctx, s.cancel = context.WithCancel(context.Background())
if s.Config.Key != "" {
err := s.readFile(s.Config.BucketName, s.Config.Key)
if err != nil {
return err
}
} else {
//No key, get everything in the bucket based on the prefix
objects, err := s.getBucketContent()
if err != nil {
return err
}
for _, object := range objects {
err := s.readFile(s.Config.BucketName, *object.Key)
if err != nil {
return err
}
}
}
return nil
}
func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
s.t = t
s.out = out
s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered?
s.ctx, s.cancel = context.WithCancel(context.Background())
s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix)
t.Go(func() error {
s.readManager()
return nil
})
if s.Config.PollingMethod == PollMethodSQS {
t.Go(func() error {
err := s.sqsPoll()
if err != nil {
return err
}
return nil
})
} else {
t.Go(func() error {
err := s.listPoll()
if err != nil {
return err
}
return nil
})
}
return nil
}
func (s *S3Source) CanRun() error {
return nil
}
func (s *S3Source) Dump() interface{} {
return s
}

View file

@ -0,0 +1,429 @@
package s3acquisition
import (
"context"
"fmt"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"gopkg.in/tomb.v2"
)
func TestBadConfiguration(t *testing.T) {
tests := []struct {
name string
config string
expectedErr string
}{
{
name: "no bucket",
config: `
source: s3
`,
expectedErr: "bucket_name is required",
},
{
name: "invalid polling method",
config: `
source: s3
bucket_name: foobar
polling_method: foobar
`,
expectedErr: "invalid polling method foobar",
},
{
name: "no sqs name",
config: `
source: s3
bucket_name: foobar
polling_method: sqs
`,
expectedErr: "sqs_name is required when using sqs polling method",
},
{
name: "both bucket and sqs",
config: `
source: s3
bucket_name: foobar
polling_method: sqs
sqs_name: foobar
`,
expectedErr: "bucket_name and sqs_name are mutually exclusive",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := S3Source{}
err := f.Configure([]byte(test.config), nil)
if err == nil {
t.Fatalf("expected error, got none")
}
if err.Error() != test.expectedErr {
t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error())
}
})
}
}
func TestGoodConfiguration(t *testing.T) {
tests := []struct {
name string
config string
}{
{
name: "basic",
config: `
source: s3
bucket_name: foobar
`,
},
{
name: "polling method",
config: `
source: s3
polling_method: sqs
sqs_name: foobar
`,
},
{
name: "list method",
config: `
source: s3
bucket_name: foobar
polling_method: list
`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := S3Source{}
logger := log.NewEntry(log.New())
err := f.Configure([]byte(test.config), logger)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
})
}
}
type mockS3Client struct {
s3iface.S3API
}
// We add one hour to trick the listing goroutine into thinking the files are new
var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{
"bucket_no_prefix": {
{
Key: aws.String("foo.log"),
LastModified: aws.Time(time.Now().Add(time.Hour)),
},
},
"bucket_with_prefix": {
{
Key: aws.String("prefix/foo.log"),
LastModified: aws.Time(time.Now().Add(time.Hour)),
},
{
Key: aws.String("prefix/bar.log"),
LastModified: aws.Time(time.Now().Add(time.Hour)),
},
},
}
func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) {
log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket])
return &s3.ListObjectsV2Output{
Contents: mockListOutput[*input.Bucket],
}, nil
}
func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) {
r := strings.NewReader("foo\nbar")
return &s3.GetObjectOutput{
Body: aws.ReadSeekCloser(r),
}, nil
}
type mockSQSClient struct {
sqsiface.SQSAPI
counter *int32
}
func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
if atomic.LoadInt32(msqs.counter) == 1 {
return &sqs.ReceiveMessageOutput{}, nil
}
atomic.AddInt32(msqs.counter, 1)
return &sqs.ReceiveMessageOutput{
Messages: []*sqs.Message{
{
Body: aws.String(`
{"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`),
},
},
}, nil
}
func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
return &sqs.DeleteMessageOutput{}, nil
}
type mockSQSClientNotif struct {
sqsiface.SQSAPI
counter *int32
}
func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
if atomic.LoadInt32(msqs.counter) == 1 {
return &sqs.ReceiveMessageOutput{}, nil
}
atomic.AddInt32(msqs.counter, 1)
return &sqs.ReceiveMessageOutput{
Messages: []*sqs.Message{
{
Body: aws.String(`
{"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`),
},
},
}, nil
}
func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
return &sqs.DeleteMessageOutput{}, nil
}
func TestDSNAcquis(t *testing.T) {
tests := []struct {
name string
dsn string
expectedBucketName string
expectedPrefix string
expectedCount int
}{
{
name: "basic",
dsn: "s3://bucket_no_prefix/foo.log",
expectedBucketName: "bucket_no_prefix",
expectedPrefix: "",
expectedCount: 2,
},
{
name: "with prefix",
dsn: "s3://bucket_with_prefix/prefix/",
expectedBucketName: "bucket_with_prefix",
expectedPrefix: "prefix/",
expectedCount: 4,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
linesRead := 0
f := S3Source{}
logger := log.NewEntry(log.New())
err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
assert.Equal(t, test.expectedBucketName, f.Config.BucketName)
assert.Equal(t, test.expectedPrefix, f.Config.Prefix)
out := make(chan types.Event)
done := make(chan bool)
go func() {
for {
select {
case s := <-out:
fmt.Printf("got line %s\n", s.Line.Raw)
linesRead++
case <-done:
return
}
}
}()
f.s3Client = mockS3Client{}
err = f.OneShotAcquisition(out, nil)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
time.Sleep(2 * time.Second)
done <- true
assert.Equal(t, test.expectedCount, linesRead)
})
}
}
func TestListPolling(t *testing.T) {
tests := []struct {
name string
config string
expectedCount int
}{
{
name: "basic",
config: `
source: s3
bucket_name: bucket_no_prefix
polling_method: list
polling_interval: 1
`,
expectedCount: 2,
},
{
name: "with prefix",
config: `
source: s3
bucket_name: bucket_with_prefix
polling_method: list
polling_interval: 1
prefix: foo/
`,
expectedCount: 4,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
linesRead := 0
f := S3Source{}
logger := log.NewEntry(log.New())
logger.Logger.SetLevel(log.TraceLevel)
err := f.Configure([]byte(test.config), logger)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
if f.Config.PollingMethod != PollMethodList {
t.Fatalf("expected list polling, got %s", f.Config.PollingMethod)
}
f.s3Client = mockS3Client{}
out := make(chan types.Event)
tb := tomb.Tomb{}
go func() {
for {
select {
case s := <-out:
fmt.Printf("got line %s\n", s.Line.Raw)
linesRead++
case <-tb.Dying():
return
}
}
}()
err = f.StreamingAcquisition(out, &tb)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
time.Sleep(2 * time.Second)
tb.Kill(nil)
err = tb.Wait()
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
assert.Equal(t, test.expectedCount, linesRead)
})
}
}
func TestSQSPoll(t *testing.T) {
tests := []struct {
name string
config string
notifType string
expectedCount int
}{
{
name: "eventbridge",
config: `
source: s3
polling_method: sqs
sqs_name: test
`,
expectedCount: 2,
notifType: "eventbridge",
},
{
name: "notification",
config: `
source: s3
polling_method: sqs
sqs_name: test
`,
expectedCount: 2,
notifType: "notification",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
linesRead := 0
f := S3Source{}
logger := log.NewEntry(log.New())
err := f.Configure([]byte(test.config), logger)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
if f.Config.PollingMethod != PollMethodSQS {
t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod)
}
counter := int32(0)
f.s3Client = mockS3Client{}
if test.notifType == "eventbridge" {
f.sqsClient = mockSQSClient{counter: &counter}
} else {
f.sqsClient = mockSQSClientNotif{counter: &counter}
}
out := make(chan types.Event)
tb := tomb.Tomb{}
go func() {
for {
select {
case s := <-out:
fmt.Printf("got line %s\n", s.Line.Raw)
linesRead++
case <-tb.Dying():
return
}
}
}()
err = f.StreamingAcquisition(out, &tb)
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
time.Sleep(2 * time.Second)
tb.Kill(nil)
err = tb.Wait()
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
assert.Equal(t, test.expectedCount, linesRead)
})
}
}