From 9d2b5dc07d8acb4cead8bbf399e313a71da53f63 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Mon, 1 Aug 2022 18:48:54 +0200 Subject: [PATCH] refactor: move eventmanager to common package Signed-off-by: Nicola Murino --- go.mod | 6 +- go.sum | 12 +- internal/common/actions.go | 4 +- internal/common/common.go | 165 ++++-- internal/common/common_test.go | 144 ++--- internal/common/eventmanager.go | 776 +++++++++++++++++++++++++ internal/common/eventmanager_test.go | 674 +++++++++++++++++++++ internal/common/eventscheduler.go | 43 ++ internal/common/protocol_test.go | 188 +++++- internal/dataprovider/actions.go | 13 +- internal/dataprovider/dataprovider.go | 92 ++- internal/dataprovider/eventrule.go | 353 +---------- internal/dataprovider/eventruleutil.go | 474 --------------- internal/dataprovider/quota.go | 141 ----- internal/dataprovider/scheduler.go | 9 +- internal/httpd/api_maintenance.go | 5 +- internal/httpd/api_quota.go | 21 +- internal/httpd/httpd_test.go | 24 +- internal/httpd/internal_test.go | 2 +- internal/httpdtest/httpdtest.go | 8 +- internal/sftpd/internal_test.go | 2 +- internal/sftpd/sftpd_test.go | 26 +- internal/util/util.go | 7 + internal/vfs/cryptfs.go | 2 +- 24 files changed, 2030 insertions(+), 1161 deletions(-) create mode 100644 internal/common/eventmanager.go create mode 100644 internal/common/eventmanager_test.go create mode 100644 internal/common/eventscheduler.go delete mode 100644 internal/dataprovider/eventruleutil.go delete mode 100644 internal/dataprovider/quota.go diff --git a/go.mod b/go.mod index c0a4726c..0ab088eb 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( github.com/rs/xid v1.4.0 github.com/rs/zerolog v1.27.0 github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a - github.com/shirou/gopsutil/v3 v3.22.6 + github.com/shirou/gopsutil/v3 v3.22.7 github.com/spf13/afero v1.9.2 github.com/spf13/cobra v1.5.0 github.com/spf13/viper v1.12.0 @@ -68,7 +68,7 @@ require ( golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/net v0.0.0-20220728211354-c7608f3a8462 golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 + golang.org/x/sys v0.0.0-20220731174439-a90be440212d golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 google.golang.org/api v0.90.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 @@ -155,7 +155,7 @@ require ( golang.org/x/tools v0.1.12 // indirect golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 // indirect + google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f // indirect google.golang.org/grpc v1.48.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.66.6 // indirect diff --git a/go.sum b/go.sum index f5c7ced6..c6cf433f 100644 --- a/go.sum +++ b/go.sum @@ -712,8 +712,8 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a h1:X9qPZ+GPQ87TnBDNZN6dyX7FkjhwnFh98WgB6Y1T5O8= github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a/go.mod h1:RL4HeorXC6XgqtkLYnQUSogLdsdMfbsogIvdBVLuy4w= -github.com/shirou/gopsutil/v3 v3.22.6 h1:FnHOFOh+cYAM0C30P+zysPISzlknLC5Z1G4EAElznfQ= -github.com/shirou/gopsutil/v3 v3.22.6/go.mod h1:EdIubSnZhbAvBS1yJ7Xi+AShB/hxwLHOMz4MCYz7yMs= +github.com/shirou/gopsutil/v3 v3.22.7 h1:flKnuCMfUUrO+oAvwAd6GKZgnPzr098VA/UJ14nhJd4= +github.com/shirou/gopsutil/v3 v3.22.7/go.mod h1:s648gW4IywYzUfE/KjXxUsqrqx/T2xO5VqOXxONeRfI= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= @@ -746,7 +746,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 h1:b2nJXyPCa9HY7giGM+kYcnQ71m14JnGdQabMPmyt++8= @@ -973,8 +972,9 @@ golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1225,8 +1225,8 @@ google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljW google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= -google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 h1:d3fKQZK+1rWQMg3xLKQbPMirUCo29I/NRdI2WarSzTg= -google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc= +google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f h1:XVHpVMvPs4MtH3h6cThzKs2snNexcfd35vQx2T3IuIY= +google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/internal/common/actions.go b/internal/common/actions.go index cf6b026b..00f01f0e 100644 --- a/internal/common/actions.go +++ b/internal/common/actions.go @@ -102,7 +102,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua ) error { hasNotifiersPlugin := plugin.Handler.HasNotifiers() hasHook := util.Contains(Config.Actions.ExecuteOn, operation) - hasRules := dataprovider.EventManager.HasFsRules() + hasRules := eventManager.hasFsRules() if !hasHook && !hasNotifiersPlugin && !hasRules { return nil } @@ -113,7 +113,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua } var errRes error if hasRules { - errRes = dataprovider.EventManager.HandleFsEvent(dataprovider.EventParams{ + errRes = eventManager.handleFsEvent(EventParams{ Name: notification.Username, Event: notification.Action, Status: notification.Status, diff --git a/internal/common/common.go b/internal/common/common.go index e9d074f3..2c1f641e 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -138,11 +138,11 @@ var ( // Config is the configuration for the supported protocols Config Configuration // Connections is the list of active connections - Connections ActiveConnections - transfersChecker TransfersChecker - periodicTimeoutTicker *time.Ticker - periodicTimeoutTickerDone chan bool - supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, + Connections ActiveConnections + // QuotaScans is the list of active quota scans + QuotaScans ActiveScans + transfersChecker TransfersChecker + supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} // the map key is the protocol, for each protocol we can have multiple rate limiters @@ -157,7 +157,7 @@ func Initialize(c Configuration, isShared int) error { Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true) Config.idleLoginTimeout = 2 * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute - startPeriodicTimeoutTicker(periodicTimeoutCheckInterval) + startPeriodicChecks(periodicTimeoutCheckInterval) Config.defender = nil Config.whitelist = nil rateLimiters = make(map[string][]*rateLimiter) @@ -308,35 +308,18 @@ func AddDefenderEvent(ip string, event HostEvent) { Config.defender.AddEvent(ip, event) } -// the ticker cannot be started/stopped from multiple goroutines -func startPeriodicTimeoutTicker(duration time.Duration) { - stopPeriodicTimeoutTicker() - periodicTimeoutTicker = time.NewTicker(duration) - periodicTimeoutTickerDone = make(chan bool) - go func() { - counter := int64(0) +func startPeriodicChecks(duration time.Duration) { + startEventScheduler() + spec := fmt.Sprintf("@every %s", duration) + _, err := eventScheduler.AddFunc(spec, Connections.checkTransfers) + util.PanicOnError(err) + logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec) + if Config.IdleTimeout > 0 { ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval - for { - select { - case <-periodicTimeoutTickerDone: - return - case <-periodicTimeoutTicker.C: - counter++ - if Config.IdleTimeout > 0 && counter >= int64(ratio) { - counter = 0 - Connections.checkIdles() - } - go Connections.checkTransfers() - } - } - }() -} - -func stopPeriodicTimeoutTicker() { - if periodicTimeoutTicker != nil { - periodicTimeoutTicker.Stop() - periodicTimeoutTickerDone <- true - periodicTimeoutTicker = nil + spec = fmt.Sprintf("@every %s", duration*ratio) + _, err = eventScheduler.AddFunc(spec, Connections.checkIdles) + util.PanicOnError(err) + logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec) } } @@ -1162,3 +1145,117 @@ func (c *ConnectionStatus) GetTransfersAsString() string { } return result } + +// ActiveQuotaScan defines an active quota scan for a user home dir +type ActiveQuotaScan struct { + // Username to which the quota scan refers + Username string `json:"username"` + // quota scan start time as unix timestamp in milliseconds + StartTime int64 `json:"start_time"` +} + +// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder +type ActiveVirtualFolderQuotaScan struct { + // folder name to which the quota scan refers + Name string `json:"name"` + // quota scan start time as unix timestamp in milliseconds + StartTime int64 `json:"start_time"` +} + +// ActiveScans holds the active quota scans +type ActiveScans struct { + sync.RWMutex + UserScans []ActiveQuotaScan + FolderScans []ActiveVirtualFolderQuotaScan +} + +// GetUsersQuotaScans returns the active quota scans for users home directories +func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan { + s.RLock() + defer s.RUnlock() + + scans := make([]ActiveQuotaScan, len(s.UserScans)) + copy(scans, s.UserScans) + return scans +} + +// AddUserQuotaScan adds a user to the ones with active quota scans. +// Returns false if the user has a quota scan already running +func (s *ActiveScans) AddUserQuotaScan(username string) bool { + s.Lock() + defer s.Unlock() + + for _, scan := range s.UserScans { + if scan.Username == username { + return false + } + } + s.UserScans = append(s.UserScans, ActiveQuotaScan{ + Username: username, + StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + return true +} + +// RemoveUserQuotaScan removes a user from the ones with active quota scans. +// Returns false if the user has no active quota scans +func (s *ActiveScans) RemoveUserQuotaScan(username string) bool { + s.Lock() + defer s.Unlock() + + for idx, scan := range s.UserScans { + if scan.Username == username { + lastIdx := len(s.UserScans) - 1 + s.UserScans[idx] = s.UserScans[lastIdx] + s.UserScans = s.UserScans[:lastIdx] + return true + } + } + + return false +} + +// GetVFoldersQuotaScans returns the active quota scans for virtual folders +func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan { + s.RLock() + defer s.RUnlock() + scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans)) + copy(scans, s.FolderScans) + return scans +} + +// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans. +// Returns false if the folder has a quota scan already running +func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool { + s.Lock() + defer s.Unlock() + + for _, scan := range s.FolderScans { + if scan.Name == folderName { + return false + } + } + s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{ + Name: folderName, + StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + return true +} + +// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans. +// Returns false if the folder has no active quota scans +func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool { + s.Lock() + defer s.Unlock() + + for idx, scan := range s.FolderScans { + if scan.Name == folderName { + lastIdx := len(s.FolderScans) - 1 + s.FolderScans[idx] = s.FolderScans[lastIdx] + s.FolderScans = s.FolderScans[:lastIdx] + return true + } + } + + return false +} diff --git a/internal/common/common_test.go b/internal/common/common_test.go index 5fc85f02..37b3f25e 100644 --- a/internal/common/common_test.go +++ b/internal/common/common_test.go @@ -21,7 +21,6 @@ import ( "net" "os" "os/exec" - "path" "path/filepath" "runtime" "strings" @@ -533,19 +532,19 @@ func TestIdleConnections(t *testing.T) { assert.Len(t, Connections.sshConnections, 2) Connections.RUnlock() - startPeriodicTimeoutTicker(100 * time.Millisecond) - assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond) + startPeriodicChecks(100 * time.Millisecond) + assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { Connections.RLock() defer Connections.RUnlock() return len(Connections.sshConnections) == 1 }, 1*time.Second, 200*time.Millisecond) - stopPeriodicTimeoutTicker() + stopEventScheduler() assert.Len(t, Connections.GetStats(), 2) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() sshConn2.lastActivity = c.lastActivity - startPeriodicTimeoutTicker(100 * time.Millisecond) + startPeriodicChecks(100 * time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { Connections.RLock() @@ -553,7 +552,7 @@ func TestIdleConnections(t *testing.T) { return len(Connections.sshConnections) == 0 }, 1*time.Second, 200*time.Millisecond) assert.Equal(t, int32(0), Connections.GetClientConnections()) - stopPeriodicTimeoutTicker() + stopEventScheduler() assert.True(t, customConn1.isClosed) assert.True(t, customConn2.isClosed) @@ -719,6 +718,35 @@ func TestConnectionStatus(t *testing.T) { assert.Len(t, stats, 0) } +func TestQuotaScans(t *testing.T) { + username := "username" + assert.True(t, QuotaScans.AddUserQuotaScan(username)) + assert.False(t, QuotaScans.AddUserQuotaScan(username)) + usersScans := QuotaScans.GetUsersQuotaScans() + if assert.Len(t, usersScans, 1) { + assert.Equal(t, usersScans[0].Username, username) + assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime) + QuotaScans.UserScans[0].StartTime = 0 + assert.NotEqual(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime) + } + + assert.True(t, QuotaScans.RemoveUserQuotaScan(username)) + assert.False(t, QuotaScans.RemoveUserQuotaScan(username)) + assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0) + assert.Len(t, usersScans, 1) + + folderName := "folder" + assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName)) + assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName)) + if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) { + assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].Name, folderName) + } + + assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName)) + assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName)) + assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0) +} + func TestProxyProtocolVersion(t *testing.T) { c := Configuration{ ProxyProtocol: 0, @@ -1033,110 +1061,6 @@ func TestUserRecentActivity(t *testing.T) { assert.True(t, res) } -func TestEventRuleMatch(t *testing.T) { - conditions := dataprovider.EventConditions{ - ProviderEvents: []string{"add", "update"}, - Options: dataprovider.ConditionOptions{ - Names: []dataprovider.ConditionPattern{ - { - Pattern: "user1", - InverseMatch: true, - }, - }, - }, - } - res := conditions.ProviderEventMatch(dataprovider.EventParams{ - Name: "user1", - Event: "add", - }) - assert.False(t, res) - res = conditions.ProviderEventMatch(dataprovider.EventParams{ - Name: "user2", - Event: "update", - }) - assert.True(t, res) - res = conditions.ProviderEventMatch(dataprovider.EventParams{ - Name: "user2", - Event: "delete", - }) - assert.False(t, res) - conditions.Options.ProviderObjects = []string{"api_key"} - res = conditions.ProviderEventMatch(dataprovider.EventParams{ - Name: "user2", - Event: "update", - ObjectType: "share", - }) - assert.False(t, res) - res = conditions.ProviderEventMatch(dataprovider.EventParams{ - Name: "user2", - Event: "update", - ObjectType: "api_key", - }) - assert.True(t, res) - // now test fs events - conditions = dataprovider.EventConditions{ - FsEvents: []string{operationUpload, operationDownload}, - Options: dataprovider.ConditionOptions{ - Names: []dataprovider.ConditionPattern{ - { - Pattern: "user*", - }, - { - Pattern: "tester*", - }, - }, - FsPaths: []dataprovider.ConditionPattern{ - { - Pattern: "*.txt", - }, - }, - Protocols: []string{ProtocolSFTP}, - MinFileSize: 10, - MaxFileSize: 30, - }, - } - params := dataprovider.EventParams{ - Name: "tester4", - Event: operationDelete, - VirtualPath: "/path.txt", - Protocol: ProtocolSFTP, - ObjectName: "path.txt", - FileSize: 20, - } - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.Event = operationDownload - res = conditions.FsEventMatch(params) - assert.True(t, res) - params.Name = "name" - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.Name = "user5" - res = conditions.FsEventMatch(params) - assert.True(t, res) - params.VirtualPath = "/sub/f.jpg" - params.ObjectName = path.Base(params.VirtualPath) - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.VirtualPath = "/sub/f.txt" - params.ObjectName = path.Base(params.VirtualPath) - res = conditions.FsEventMatch(params) - assert.True(t, res) - params.Protocol = ProtocolHTTP - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.Protocol = ProtocolSFTP - params.FileSize = 5 - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.FileSize = 50 - res = conditions.FsEventMatch(params) - assert.False(t, res) - params.FileSize = 25 - res = conditions.FsEventMatch(params) - assert.True(t, res) -} - func BenchmarkBcryptHashing(b *testing.B) { bcryptPassword := "bcryptpassword" for i := 0; i < b.N; i++ { diff --git a/internal/common/eventmanager.go b/internal/common/eventmanager.go new file mode 100644 index 00000000..c766f415 --- /dev/null +++ b/internal/common/eventmanager.go @@ -0,0 +1,776 @@ +// Copyright (C) 2019-2022 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "path" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/robfig/cron/v3" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + // eventManager handle the supported event rules actions + eventManager eventRulesContainer +) + +func init() { + eventManager = eventRulesContainer{ + schedulesMapping: make(map[string][]cron.EntryID), + } + dataprovider.SetEventRulesCallbacks(eventManager.loadRules, eventManager.RemoveRule, + func(operation, executor, ip, objectType, objectName string, object plugin.Renderer) { + eventManager.handleProviderEvent(EventParams{ + Name: executor, + ObjectName: objectName, + Event: operation, + Status: 1, + ObjectType: objectType, + IP: ip, + Timestamp: time.Now().UnixNano(), + Object: object, + }) + }) +} + +// eventRulesContainer stores event rules by trigger +type eventRulesContainer struct { + sync.RWMutex + FsEvents []dataprovider.EventRule + ProviderEvents []dataprovider.EventRule + Schedules []dataprovider.EventRule + schedulesMapping map[string][]cron.EntryID + lastLoad int64 +} + +func (r *eventRulesContainer) getLastLoadTime() int64 { + return atomic.LoadInt64(&r.lastLoad) +} + +func (r *eventRulesContainer) setLastLoadTime(modTime int64) { + atomic.StoreInt64(&r.lastLoad, modTime) +} + +// RemoveRule deletes the rule with the specified name +func (r *eventRulesContainer) RemoveRule(name string) { + r.Lock() + defer r.Unlock() + + r.removeRuleInternal(name) + eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d", + len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) +} + +func (r *eventRulesContainer) removeRuleInternal(name string) { + for idx := range r.FsEvents { + if r.FsEvents[idx].Name == name { + lastIdx := len(r.FsEvents) - 1 + r.FsEvents[idx] = r.FsEvents[lastIdx] + r.FsEvents = r.FsEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name) + return + } + } + for idx := range r.ProviderEvents { + if r.ProviderEvents[idx].Name == name { + lastIdx := len(r.ProviderEvents) - 1 + r.ProviderEvents[idx] = r.ProviderEvents[lastIdx] + r.ProviderEvents = r.ProviderEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name) + return + } + } + for idx := range r.Schedules { + if r.Schedules[idx].Name == name { + if schedules, ok := r.schedulesMapping[name]; ok { + for _, entryID := range schedules { + eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name) + eventScheduler.Remove(entryID) + } + delete(r.schedulesMapping, name) + } + + lastIdx := len(r.Schedules) - 1 + r.Schedules[idx] = r.Schedules[lastIdx] + r.Schedules = r.Schedules[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name) + return + } + } +} + +func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule) { + r.removeRuleInternal(rule.Name) + if rule.DeletedAt > 0 { + deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt) + if deletedAt.Add(30 * time.Minute).Before(time.Now()) { + eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt) + go dataprovider.RemoveEventRule(rule) //nolint:errcheck + } + return + } + switch rule.Trigger { + case dataprovider.EventTriggerFsEvent: + r.FsEvents = append(r.FsEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name) + case dataprovider.EventTriggerProviderEvent: + r.ProviderEvents = append(r.ProviderEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name) + case dataprovider.EventTriggerSchedule: + for _, schedule := range rule.Conditions.Schedules { + cronSpec := schedule.GetCronSpec() + job := &eventCronJob{ + ruleName: dataprovider.ConvertName(rule.Name), + } + entryID, err := eventScheduler.AddJob(cronSpec, job) + if err != nil { + eventManagerLog(logger.LevelError, "unable to add scheduled rule %q, cron string %q: %v", rule.Name, cronSpec, err) + return + } + r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID) + eventManagerLog(logger.LevelDebug, "schedule for rule %q added, id: %d, cron string %q, active scheduling rules: %d", + rule.Name, entryID, cronSpec, len(r.schedulesMapping)) + } + r.Schedules = append(r.Schedules, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name) + default: + eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger) + } +} + +func (r *eventRulesContainer) loadRules() { + eventManagerLog(logger.LevelDebug, "loading updated rules") + modTime := util.GetTimeAsMsSinceEpoch(time.Now()) + rules, err := dataprovider.GetRecentlyUpdatedRules(r.getLastLoadTime()) + if err != nil { + eventManagerLog(logger.LevelError, "unable to load event rules: %v", err) + return + } + eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules)) + + if len(rules) > 0 { + r.Lock() + defer r.Unlock() + + for _, rule := range rules { + r.addUpdateRuleInternal(rule) + } + } + eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d", + len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) + + r.setLastLoadTime(modTime) +} + +func (r *eventRulesContainer) checkProviderEventMatch(conditions dataprovider.EventConditions, params EventParams) bool { + if !util.Contains(conditions.ProviderEvents, params.Event) { + return false + } + if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { + return false + } + if len(conditions.Options.ProviderObjects) > 0 && !util.Contains(conditions.Options.ProviderObjects, params.ObjectType) { + return false + } + return true +} + +func (r *eventRulesContainer) checkFsEventMatch(conditions dataprovider.EventConditions, params EventParams) bool { + if !util.Contains(conditions.FsEvents, params.Event) { + return false + } + if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { + return false + } + if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) { + if !checkEventConditionPatterns(params.ObjectName, conditions.Options.FsPaths) { + return false + } + } + if len(conditions.Options.Protocols) > 0 && !util.Contains(conditions.Options.Protocols, params.Protocol) { + return false + } + if params.Event == operationUpload || params.Event == operationDownload { + if conditions.Options.MinFileSize > 0 { + if params.FileSize < conditions.Options.MinFileSize { + return false + } + } + if conditions.Options.MaxFileSize > 0 { + if params.FileSize > conditions.Options.MaxFileSize { + return false + } + } + } + return true +} + +// hasFsRules returns true if there are any rules for filesystem event triggers +func (r *eventRulesContainer) hasFsRules() bool { + r.RLock() + defer r.RUnlock() + + return len(r.FsEvents) > 0 +} + +// handleFsEvent executes the rules actions defined for the specified event +func (r *eventRulesContainer) handleFsEvent(params EventParams) error { + r.RLock() + + var rulesWithSyncActions, rulesAsync []dataprovider.EventRule + for _, rule := range r.FsEvents { + if r.checkFsEventMatch(rule.Conditions, params) { + hasSyncActions := false + for _, action := range rule.Actions { + if action.Options.ExecuteSync { + hasSyncActions = true + break + } + } + if hasSyncActions { + rulesWithSyncActions = append(rulesWithSyncActions, rule) + } else { + rulesAsync = append(rulesAsync, rule) + } + } + } + + r.RUnlock() + + if len(rulesAsync) > 0 { + go executeAsyncRulesActions(rulesAsync, params) + } + + if len(rulesWithSyncActions) > 0 { + return executeSyncRulesActions(rulesWithSyncActions, params) + } + return nil +} + +func (r *eventRulesContainer) handleProviderEvent(params EventParams) { + r.RLock() + defer r.RUnlock() + + var rules []dataprovider.EventRule + for _, rule := range r.ProviderEvents { + if r.checkProviderEventMatch(rule.Conditions, params) { + rules = append(rules, rule) + } + } + + if len(rules) > 0 { + go executeAsyncRulesActions(rules, params) + } +} + +// EventParams defines the supported event parameters +type EventParams struct { + Name string + Event string + Status int + VirtualPath string + FsPath string + VirtualTargetPath string + FsTargetPath string + ObjectName string + ObjectType string + FileSize int64 + Protocol string + IP string + Timestamp int64 + Object plugin.Renderer +} + +func (p *EventParams) getStringReplacements(addObjectData bool) []string { + replacements := []string{ + "{{Name}}", p.Name, + "{{Event}}", p.Event, + "{{Status}}", fmt.Sprintf("%d", p.Status), + "{{VirtualPath}}", p.VirtualPath, + "{{FsPath}}", p.FsPath, + "{{VirtualTargetPath}}", p.VirtualTargetPath, + "{{FsTargetPath}}", p.FsTargetPath, + "{{ObjectName}}", p.ObjectName, + "{{ObjectType}}", p.ObjectType, + "{{FileSize}}", fmt.Sprintf("%d", p.FileSize), + "{{Protocol}}", p.Protocol, + "{{IP}}", p.IP, + "{{Timestamp}}", fmt.Sprintf("%d", p.Timestamp), + } + if addObjectData { + data, err := p.Object.RenderAsJSON(p.Event != operationDelete) + if err == nil { + replacements = append(replacements, "{{ObjectData}}", string(data)) + } + } + return replacements +} + +func replaceWithReplacer(input string, replacer *strings.Replacer) string { + if !strings.Contains(input, "{{") { + return input + } + return replacer.Replace(input) +} + +func checkEventConditionPattern(p dataprovider.ConditionPattern, name string) bool { + matched, err := path.Match(p.Pattern, name) + if err != nil { + eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err) + return false + } + if p.InverseMatch { + return !matched + } + return matched +} + +// checkConditionPatterns returns false if patterns are defined and no match is found +func checkEventConditionPatterns(name string, patterns []dataprovider.ConditionPattern) bool { + if len(patterns) == 0 { + return true + } + for _, p := range patterns { + if checkEventConditionPattern(p, name) { + return true + } + } + + return false +} + +func getHTTPRuleActionEndpoint(c dataprovider.EventActionHTTPConfig, replacer *strings.Replacer) (string, error) { + if len(c.QueryParameters) > 0 { + u, err := url.Parse(c.Endpoint) + if err != nil { + return "", fmt.Errorf("invalid endpoint: %w", err) + } + q := u.Query() + + for _, keyVal := range c.QueryParameters { + q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) + } + + u.RawQuery = q.Encode() + return u.String(), nil + } + return c.Endpoint, nil +} + +func executeHTTPRuleAction(c dataprovider.EventActionHTTPConfig, params EventParams) error { + if !c.Password.IsEmpty() { + if err := c.Password.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt password: %w", err) + } + } + addObjectData := false + if params.Object != nil { + if !addObjectData { + if strings.Contains(c.Body, "{{ObjectData}}") { + addObjectData = true + } + } + } + + replacements := params.getStringReplacements(addObjectData) + replacer := strings.NewReplacer(replacements...) + endpoint, err := getHTTPRuleActionEndpoint(c, replacer) + if err != nil { + return err + } + + var body io.Reader + if c.Body != "" && c.Method != http.MethodGet { + body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer)) + } + req, err := http.NewRequest(c.Method, endpoint, body) + if err != nil { + return err + } + if c.Username != "" { + req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData()) + } + for _, keyVal := range c.Headers { + req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) + } + client := c.GetHTTPClient() + defer client.CloseIdleConnections() + + startTime := time.Now() + resp, err := client.Do(req) + if err != nil { + eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v", + endpoint, time.Since(startTime), err) + return err + } + defer resp.Body.Close() + + eventManagerLog(logger.LevelDebug, "http notification sent, endopoint: %s, elapsed: %s, status code: %d", + endpoint, time.Since(startTime), resp.StatusCode) + if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return nil +} + +func executeCommandRuleAction(c dataprovider.EventActionCommandConfig, params EventParams) error { + envVars := make([]string, 0, len(c.EnvVars)) + addObjectData := false + if params.Object != nil { + for _, k := range c.EnvVars { + if strings.Contains(k.Value, "{{ObjectData}}") { + addObjectData = true + break + } + } + } + replacements := params.getStringReplacements(addObjectData) + replacer := strings.NewReplacer(replacements...) + for _, keyVal := range c.EnvVars { + envVars = append(envVars, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, c.Cmd) + cmd.Env = append(cmd.Env, os.Environ()...) + cmd.Env = append(cmd.Env, envVars...) + + startTime := time.Now() + err := cmd.Run() + + eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v", + c.Cmd, time.Since(startTime), err) + + return err +} + +func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params EventParams) error { + addObjectData := false + if params.Object != nil { + if strings.Contains(c.Body, "{{ObjectData}}") { + addObjectData = true + } + } + replacements := params.getStringReplacements(addObjectData) + replacer := strings.NewReplacer(replacements...) + body := replaceWithReplacer(c.Body, replacer) + subject := replaceWithReplacer(c.Subject, replacer) + startTime := time.Now() + err := smtp.SendEmail(c.Recipients, subject, body, smtp.EmailContentTypeTextPlain) + eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v", + time.Since(startTime), err) + return err +} + +func executeUsersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error { + users, err := dataprovider.DumpUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failedResets []string + for _, user := range users { + if !checkEventConditionPatterns(user.Username, conditions.Names) { + eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match", + user.Username) + continue + } + if !QuotaScans.AddUserQuotaScan(user.Username) { + eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username) + failedResets = append(failedResets, user.Username) + continue + } + numFiles, size, err := user.ScanQuota() + QuotaScans.RemoveUserQuotaScan(user.Username) + if err != nil { + eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, err) + failedResets = append(failedResets, user.Username) + continue + } + err = dataprovider.UpdateUserQuota(&user, numFiles, size, true) + if err != nil { + eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err) + failedResets = append(failedResets, user.Username) + continue + } + } + if len(failedResets) > 0 { + return fmt.Errorf("quota reset failed for users: %+v", failedResets) + } + return nil +} + +func executeFoldersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error { + folders, err := dataprovider.DumpFolders() + if err != nil { + return fmt.Errorf("unable to get folders: %w", err) + } + var failedResets []string + for _, folder := range folders { + if !checkEventConditionPatterns(folder.Name, conditions.Names) { + eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match", + folder.Name) + continue + } + if !QuotaScans.AddVFolderQuotaScan(folder.Name) { + eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name) + failedResets = append(failedResets, folder.Name) + continue + } + f := vfs.VirtualFolder{ + BaseVirtualFolder: folder, + VirtualPath: "/", + } + numFiles, size, err := f.ScanQuota() + QuotaScans.RemoveVFolderQuotaScan(folder.Name) + if err != nil { + eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, err) + failedResets = append(failedResets, folder.Name) + continue + } + err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true) + if err != nil { + eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err) + failedResets = append(failedResets, folder.Name) + continue + } + } + if len(failedResets) > 0 { + return fmt.Errorf("quota reset failed for folders: %+v", failedResets) + } + return nil +} + +func executeTransferQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error { + users, err := dataprovider.DumpUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failedResets []string + for _, user := range users { + if !checkEventConditionPatterns(user.Username, conditions.Names) { + eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match", + user.Username) + continue + } + err = dataprovider.UpdateUserTransferQuota(&user, 0, 0, true) + if err != nil { + eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err) + failedResets = append(failedResets, user.Username) + continue + } + } + if len(failedResets) > 0 { + return fmt.Errorf("transfer quota reset failed for users: %+v", failedResets) + } + return nil +} + +func executeRuleAction(action dataprovider.BaseEventAction, params EventParams, conditions dataprovider.ConditionOptions) error { + switch action.Type { + case dataprovider.ActionTypeHTTP: + return executeHTTPRuleAction(action.Options.HTTPConfig, params) + case dataprovider.ActionTypeCommand: + return executeCommandRuleAction(action.Options.CmdConfig, params) + case dataprovider.ActionTypeEmail: + return executeEmailRuleAction(action.Options.EmailConfig, params) + case dataprovider.ActionTypeBackup: + return dataprovider.ExecuteBackup() + case dataprovider.ActionTypeUserQuotaReset: + return executeUsersQuotaResetRuleAction(conditions) + case dataprovider.ActionTypeFolderQuotaReset: + return executeFoldersQuotaResetRuleAction(conditions) + case dataprovider.ActionTypeTransferQuotaReset: + return executeTransferQuotaResetRuleAction(conditions) + default: + return fmt.Errorf("unsupported action type: %d", action.Type) + } +} + +func executeSyncRulesActions(rules []dataprovider.EventRule, params EventParams) error { + var errRes error + + for _, rule := range rules { + var failedActions []string + for _, action := range rule.Actions { + if !action.Options.IsFailureAction && action.Options.ExecuteSync { + startTime := time.Now() + if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil { + eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v", + action.Name, rule.Name, time.Since(startTime), err) + failedActions = append(failedActions, action.Name) + // we return the last error, it is ok for now + errRes = err + if action.Options.StopOnFailure { + break + } + } else { + eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s", + action.Name, rule.Name, time.Since(startTime)) + } + } + } + // execute async actions if any, including failure actions + go executeRuleAsyncActions(rule, params, failedActions) + } + + return errRes +} + +func executeAsyncRulesActions(rules []dataprovider.EventRule, params EventParams) { + for _, rule := range rules { + executeRuleAsyncActions(rule, params, nil) + } +} + +func executeRuleAsyncActions(rule dataprovider.EventRule, params EventParams, failedActions []string) { + for _, action := range rule.Actions { + if !action.Options.IsFailureAction && !action.Options.ExecuteSync { + startTime := time.Now() + if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil { + eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v", + action.Name, rule.Name, time.Since(startTime), err) + failedActions = append(failedActions, action.Name) + if action.Options.StopOnFailure { + break + } + } else { + eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s", + action.Name, rule.Name, time.Since(startTime)) + } + } + } + if len(failedActions) > 0 { + // execute failure actions + for _, action := range rule.Actions { + if action.Options.IsFailureAction { + startTime := time.Now() + if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil { + eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v", + action.Name, rule.Name, time.Since(startTime), err) + if action.Options.StopOnFailure { + break + } + } else { + eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s", + action.Name, rule.Name, time.Since(startTime)) + } + } + } + } +} + +type eventCronJob struct { + ruleName string +} + +func (j *eventCronJob) getTask(rule dataprovider.EventRule) (dataprovider.Task, error) { + if rule.GuardFromConcurrentExecution() { + task, err := dataprovider.GetTaskByName(rule.Name) + if _, ok := err.(*util.RecordNotFoundError); ok { + eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name) + task = dataprovider.Task{ + Name: rule.Name, + UpdateAt: 0, + Version: 0, + } + err = dataprovider.AddTask(rule.Name) + if err != nil { + eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err) + return task, err + } + } else { + eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err) + } + return task, err + } + + return dataprovider.Task{}, nil +} + +func (j *eventCronJob) Run() { + eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName) + rule, err := dataprovider.EventRuleExists(j.ruleName) + if err != nil { + eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName) + return + } + task, err := j.getTask(rule) + if err != nil { + return + } + if task.Name != "" { + updateInterval := 5 * time.Minute + updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt) + if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) { + eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt) + return + } + err = dataprovider.UpdateTask(rule.Name, task.Version) + if err != nil { + eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v", + rule.Name, err) + return + } + ticker := time.NewTicker(updateInterval) + done := make(chan bool) + + go func(taskName string) { + eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName) + for { + select { + case <-done: + eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName) + return + case <-ticker.C: + err := dataprovider.UpdateTaskTimestamp(taskName) + eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err) + } + } + }(task.Name) + + executeRuleAsyncActions(rule, EventParams{}, nil) + + done <- true + ticker.Stop() + } else { + executeRuleAsyncActions(rule, EventParams{}, nil) + } + eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName) +} + +func eventManagerLog(level logger.LogLevel, format string, v ...any) { + logger.Log(level, "eventmanager", "", format, v...) +} diff --git a/internal/common/eventmanager_test.go b/internal/common/eventmanager_test.go new file mode 100644 index 00000000..7cff17d0 --- /dev/null +++ b/internal/common/eventmanager_test.go @@ -0,0 +1,674 @@ +// Copyright (C) 2019-2022 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "testing" + "time" + + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func TestEventRuleMatch(t *testing.T) { + conditions := dataprovider.EventConditions{ + ProviderEvents: []string{"add", "update"}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "user1", + InverseMatch: true, + }, + }, + }, + } + res := eventManager.checkProviderEventMatch(conditions, EventParams{ + Name: "user1", + Event: "add", + }) + assert.False(t, res) + res = eventManager.checkProviderEventMatch(conditions, EventParams{ + Name: "user2", + Event: "update", + }) + assert.True(t, res) + res = eventManager.checkProviderEventMatch(conditions, EventParams{ + Name: "user2", + Event: "delete", + }) + assert.False(t, res) + conditions.Options.ProviderObjects = []string{"api_key"} + res = eventManager.checkProviderEventMatch(conditions, EventParams{ + Name: "user2", + Event: "update", + ObjectType: "share", + }) + assert.False(t, res) + res = eventManager.checkProviderEventMatch(conditions, EventParams{ + Name: "user2", + Event: "update", + ObjectType: "api_key", + }) + assert.True(t, res) + // now test fs events + conditions = dataprovider.EventConditions{ + FsEvents: []string{operationUpload, operationDownload}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "user*", + }, + { + Pattern: "tester*", + }, + }, + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "*.txt", + }, + }, + Protocols: []string{ProtocolSFTP}, + MinFileSize: 10, + MaxFileSize: 30, + }, + } + params := EventParams{ + Name: "tester4", + Event: operationDelete, + VirtualPath: "/path.txt", + Protocol: ProtocolSFTP, + ObjectName: "path.txt", + FileSize: 20, + } + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.Event = operationDownload + res = eventManager.checkFsEventMatch(conditions, params) + assert.True(t, res) + params.Name = "name" + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.Name = "user5" + res = eventManager.checkFsEventMatch(conditions, params) + assert.True(t, res) + params.VirtualPath = "/sub/f.jpg" + params.ObjectName = path.Base(params.VirtualPath) + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.VirtualPath = "/sub/f.txt" + params.ObjectName = path.Base(params.VirtualPath) + res = eventManager.checkFsEventMatch(conditions, params) + assert.True(t, res) + params.Protocol = ProtocolHTTP + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.Protocol = ProtocolSFTP + params.FileSize = 5 + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.FileSize = 50 + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) + params.FileSize = 25 + res = eventManager.checkFsEventMatch(conditions, params) + assert.True(t, res) + // bad pattern + conditions.Options.Names = []dataprovider.ConditionPattern{ + { + Pattern: "[-]", + }, + } + res = eventManager.checkFsEventMatch(conditions, params) + assert.False(t, res) +} + +func TestEventManager(t *testing.T) { + startEventScheduler() + action := &dataprovider.BaseEventAction{ + Name: "test_action", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "http://localhost", + Timeout: 20, + Method: http.MethodGet, + }, + }, + } + err := dataprovider.AddEventAction(action, "", "") + assert.NoError(t, err) + rule := &dataprovider.EventRule{ + Name: "rule", + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{operationUpload}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Order: 1, + }, + }, + } + + err = dataprovider.AddEventRule(rule, "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 1) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + rule.Trigger = dataprovider.EventTriggerProviderEvent + rule.Conditions = dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + } + err = dataprovider.UpdateEventRule(rule, "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 1) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + rule.Trigger = dataprovider.EventTriggerSchedule + rule.Conditions = dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "0", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + } + rule.DeletedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour)) + eventManager.addUpdateRuleInternal(*rule) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + assert.Eventually(t, func() bool { + _, err = dataprovider.EventRuleExists(rule.Name) + _, ok := err.(*util.RecordNotFoundError) + return ok + }, 2*time.Second, 100*time.Millisecond) + + rule.DeletedAt = 0 + err = dataprovider.AddEventRule(rule, "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 1) + assert.Len(t, eventManager.schedulesMapping, 1) + eventManager.RUnlock() + + err = dataprovider.DeleteEventRule(rule.Name, "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + err = dataprovider.DeleteEventAction(action.Name, "", "") + assert.NoError(t, err) + stopEventScheduler() +} + +func TestEventManagerErrors(t *testing.T) { + startEventScheduler() + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + + err = executeUsersQuotaResetRuleAction(dataprovider.ConditionOptions{}) + assert.Error(t, err) + err = executeFoldersQuotaResetRuleAction(dataprovider.ConditionOptions{}) + assert.Error(t, err) + err = executeTransferQuotaResetRuleAction(dataprovider.ConditionOptions{}) + assert.Error(t, err) + + eventManager.loadRules() + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + + // rule with invalid trigger + eventManager.addUpdateRuleInternal(dataprovider.EventRule{ + Name: "test rule", + Trigger: -1, + }) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + // rule with invalid cronspec + eventManager.addUpdateRuleInternal(dataprovider.EventRule{ + Name: "test rule", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "1000", + }, + }, + }, + }) + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + stopEventScheduler() +} + +func TestEventRuleActions(t *testing.T) { + actionName := "test rule action" + action := dataprovider.BaseEventAction{ + Name: actionName, + Type: dataprovider.ActionTypeBackup, + } + err := executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{}) + assert.NoError(t, err) + action.Type = -1 + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{}) + assert.Error(t, err) + + action = dataprovider.BaseEventAction{ + Name: actionName, + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "http://foo\x7f.com/", // invalid URL + SkipTLSVerify: true, + Body: "{{ObjectData}}", + Method: http.MethodPost, + QueryParameters: []dataprovider.KeyValue{ + { + Key: "param", + Value: "value", + }, + }, + Timeout: 5, + Headers: []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + Username: "httpuser", + }, + }, + } + action.Options.SetEmptySecretsIfNil() + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid endpoint") + } + action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr) + params := EventParams{ + Name: "a", + Object: &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test user", + }, + }, + } + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.NoError(t, err) + action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v/404", httpAddr) + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Equal(t, err.Error(), "unexpected status code: 404") + } + action.Options.HTTPConfig.Endpoint = "http://invalid:1234" + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.Error(t, err) + action.Options.HTTPConfig.QueryParameters = nil + action.Options.HTTPConfig.Endpoint = "http://bar\x7f.com/" + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.Error(t, err) + action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", "data") + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to decrypt password") + } + // test disk and transfer quota reset + username1 := "user1" + username2 := "user2" + user1 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username1, + HomeDir: filepath.Join(os.TempDir(), username1), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + user2 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username2, + HomeDir: filepath.Join(os.TempDir(), username2), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + err = dataprovider.AddUser(&user1, "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&user2, "", "") + assert.NoError(t, err) + + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeUserQuotaReset, + } + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) // no home dir + // create the home dir + err = os.MkdirAll(user1.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user1.GetHomeDir(), "file.txt"), []byte("user"), 0666) + assert.NoError(t, err) + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + userGet, err := dataprovider.UserExists(username1) + assert.NoError(t, err) + assert.Equal(t, 1, userGet.UsedQuotaFiles) + assert.Equal(t, int64(4), userGet.UsedQuotaSize) + // simulate another quota scan in progress + assert.True(t, QuotaScans.AddUserQuotaScan(username1)) + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) + assert.True(t, QuotaScans.RemoveUserQuotaScan(username1)) + + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.UpdateUserTransferQuota(&user1, 100, 100, true) + assert.NoError(t, err) + + action.Type = dataprovider.ActionTypeTransferQuotaReset + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + userGet, err = dataprovider.UserExists(username1) + assert.NoError(t, err) + assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer) + + err = dataprovider.DeleteUser(username1, "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(username2, "", "") + assert.NoError(t, err) + // test folder quota reset + foldername1 := "f1" + foldername2 := "f2" + folder1 := vfs.BaseVirtualFolder{ + Name: foldername1, + MappedPath: filepath.Join(os.TempDir(), foldername1), + } + folder2 := vfs.BaseVirtualFolder{ + Name: foldername2, + MappedPath: filepath.Join(os.TempDir(), foldername2), + } + err = dataprovider.AddFolder(&folder1, "", "") + assert.NoError(t, err) + err = dataprovider.AddFolder(&folder2, "", "") + assert.NoError(t, err) + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFolderQuotaReset, + } + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.Error(t, err) // no home dir + err = os.MkdirAll(folder1.MappedPath, os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(folder1.MappedPath, "file.txt"), []byte("folder"), 0666) + assert.NoError(t, err) + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.NoError(t, err) + folderGet, err := dataprovider.GetFolderByName(foldername1) + assert.NoError(t, err) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(6), folderGet.UsedQuotaSize) + // simulate another quota scan in progress + assert.True(t, QuotaScans.AddVFolderQuotaScan(foldername1)) + err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.Error(t, err) + assert.True(t, QuotaScans.RemoveVFolderQuotaScan(foldername1)) + + err = os.RemoveAll(folder1.MappedPath) + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername1, "", "") + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername2, "", "") + assert.NoError(t, err) +} + +func TestQuotaActionsWithQuotaTrackDisabled(t *testing.T) { + oldProviderConf := dataprovider.GetProviderConfig() + providerConf := dataprovider.GetProviderConfig() + providerConf.TrackQuota = 0 + err := dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + username := "u1" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.LocalFilesystemProvider, + }, + } + err = dataprovider.AddUser(&user, "", "") + assert.NoError(t, err) + + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeUserQuotaReset}, + EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeTransferQuotaReset}, + EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.DeleteUser(username, "", "") + assert.NoError(t, err) + + foldername := "f1" + folder := vfs.BaseVirtualFolder{ + Name: foldername, + MappedPath: filepath.Join(os.TempDir(), foldername), + } + err = dataprovider.AddFolder(&folder, "", "") + assert.NoError(t, err) + err = os.MkdirAll(folder.MappedPath, os.ModePerm) + assert.NoError(t, err) + + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeFolderQuotaReset}, + EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername, + }, + }, + }) + assert.Error(t, err) + + err = os.RemoveAll(folder.MappedPath) + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername, "", "") + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(oldProviderConf, configDir, true) + assert.NoError(t, err) +} + +func TestScheduledActions(t *testing.T) { + startEventScheduler() + backupsPath := filepath.Join(os.TempDir(), "backups") + err := os.RemoveAll(backupsPath) + assert.NoError(t, err) + + action := &dataprovider.BaseEventAction{ + Name: "action", + Type: dataprovider.ActionTypeBackup, + } + err = dataprovider.AddEventAction(action, "", "") + assert.NoError(t, err) + rule := &dataprovider.EventRule{ + Name: "rule", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "11", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Order: 1, + }, + }, + } + + job := eventCronJob{ + ruleName: rule.Name, + } + job.Run() // rule not found + assert.NoDirExists(t, backupsPath) + + err = dataprovider.AddEventRule(rule, "", "") + assert.NoError(t, err) + + job.Run() + assert.DirExists(t, backupsPath) + + err = dataprovider.DeleteEventRule(rule.Name, "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action.Name, "", "") + assert.NoError(t, err) + err = os.RemoveAll(backupsPath) + assert.NoError(t, err) + stopEventScheduler() +} diff --git a/internal/common/eventscheduler.go b/internal/common/eventscheduler.go new file mode 100644 index 00000000..aa050a21 --- /dev/null +++ b/internal/common/eventscheduler.go @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2022 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "time" + + "github.com/robfig/cron/v3" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + eventScheduler *cron.Cron +) + +func stopEventScheduler() { + if eventScheduler != nil { + eventScheduler.Stop() + eventScheduler = nil + } +} + +func startEventScheduler() { + stopEventScheduler() + + eventScheduler = cron.New(cron.WithLocation(time.UTC)) + _, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules) + util.PanicOnError(err) + eventScheduler.Start() +} diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index ec149e58..4d3aa5ca 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -18,6 +18,7 @@ import ( "bufio" "bytes" "crypto/rand" + "encoding/json" "fmt" "io" "math" @@ -77,6 +78,7 @@ var ( allPerms = []string{dataprovider.PermAny} homeBasePath string logFilePath string + backupsPath string testFileContent = []byte("test data") lastReceivedEmail receivedEmail ) @@ -84,6 +86,7 @@ var ( func TestMain(m *testing.M) { homeBasePath = os.TempDir() logFilePath = filepath.Join(configDir, "common_test.log") + backupsPath = filepath.Join(os.TempDir(), "backups") logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel) os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") @@ -95,6 +98,7 @@ func TestMain(m *testing.M) { os.Exit(1) } providerConf := config.GetProviderConf() + providerConf.BackupsPath = backupsPath logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver) err = common.Initialize(config.GetCommonConfig(), 0) @@ -203,6 +207,7 @@ func TestMain(m *testing.M) { exitCode := m.Run() os.Remove(logFilePath) + os.RemoveAll(backupsPath) os.Exit(exitCode) } @@ -2848,7 +2853,7 @@ func TestEventRule(t *testing.T) { EmailConfig: dataprovider.EventActionEmailConfig{ Recipients: []string{"test1@example.com", "test2@example.com"}, Subject: `New "{{Event}}" from "{{Name}}"`, - Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}}", + Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}} Data: {{ObjectData}}", }, }, } @@ -2987,6 +2992,10 @@ func TestEventRule(t *testing.T) { Key: "SFTPGO_ACTION_PATH", Value: "{{FsPath}}", }, + { + Key: "CUSTOM_ENV_VAR", + Value: "value", + }, }, }, } @@ -3090,6 +3099,176 @@ func TestEventRule(t *testing.T) { require.NoError(t, err) } +func TestEventRuleProviderEvents(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir) + require.NoError(t, err) + + saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh") + outPath := filepath.Join(os.TempDir(), "provider_out.json") + err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755) + assert.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeCommand, + Options: dataprovider.BaseEventActionOptions{ + CmdConfig: dataprovider.EventActionCommandConfig{ + Cmd: saveObjectScriptPath, + Timeout: 10, + EnvVars: []dataprovider.KeyValue{ + { + Key: "SFTPGO_OBJECT_DATA", + Value: "{{ObjectData}}", + }, + }, + }, + }, + } + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test3@example.com"}, + Subject: `New "{{Event}}" from "{{Name}}"`, + Body: "Object name: {{ObjectName}} object type: {{ObjectType}} Data: {{ObjectData}}", + }, + }, + } + + a3 := dataprovider.BaseEventAction{ + Name: "a3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.com"}, + Subject: `Failed "{{Event}}" from "{{Name}}"`, + Body: "Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + + r := dataprovider.EventRule{ + Name: "rule", + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"update"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + StopOnFailure: true, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + + lastReceivedEmail.reset() + // create and update a folder to trigger the rule + folder := vfs.BaseVirtualFolder{ + Name: "ftest rule", + MappedPath: filepath.Join(os.TempDir(), "p"), + } + folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err) + // no action is triggered on add + assert.NoFileExists(t, outPath) + // update the folder + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + if assert.FileExists(t, outPath) { + content, err := os.ReadFile(outPath) + assert.NoError(t, err) + var folderGet vfs.BaseVirtualFolder + err = json.Unmarshal(content, &folderGet) + assert.NoError(t, err) + assert.Equal(t, folder, folderGet) + err = os.Remove(outPath) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, util.Contains(email.To, "test3@example.com")) + assert.Contains(t, string(email.Data), `Subject: New "update" from "admin"`) + } + // now delete the script to generate an error + lastReceivedEmail.reset() + err = os.Remove(saveObjectScriptPath) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + assert.NoFileExists(t, outPath) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, util.Contains(email.To, "failure@example.com")) + assert.Contains(t, string(email.Data), `Subject: Failed "update" from "admin"`) + assert.Contains(t, string(email.Data), fmt.Sprintf("Object name: %s object type: folder", folder.Name)) + lastReceivedEmail.reset() + // generate an error for the failure action + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir) + require.NoError(t, err) + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + assert.NoFileExists(t, outPath) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 0) + + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) +} + func TestSyncUploadAction(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") @@ -4113,6 +4292,13 @@ func getUploadScriptContent(movedPath string, exitStatus int) []byte { return content } +func getSaveProviderObjectScriptContent(outFilePath string, exitStatus int) []byte { + content := []byte("#!/bin/sh\n\n") + content = append(content, []byte(fmt.Sprintf("echo ${SFTPGO_OBJECT_DATA} > %v\n", outFilePath))...) + content = append(content, []byte(fmt.Sprintf("exit %d", exitStatus))...) + return content +} + func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) { return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{ Period: 30, diff --git a/internal/dataprovider/actions.go b/internal/dataprovider/actions.go index e95f87cf..2bef46d5 100644 --- a/internal/dataprovider/actions.go +++ b/internal/dataprovider/actions.go @@ -63,17 +63,8 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec Timestamp: time.Now().UnixNano(), }, object) } - if EventManager.hasProviderEvents() { - EventManager.handleProviderEvent(EventParams{ - Name: executor, - ObjectName: objectName, - Event: operation, - Status: 1, - ObjectType: objectType, - IP: ip, - Timestamp: time.Now().UnixNano(), - Object: object, - }) + if fnHandleRuleForProviderEvent != nil { + fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, object) } if config.Actions.Hook == "" { return diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index b6a7a9fc..5fbf9c1a 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -191,6 +191,9 @@ var ( lastLoginMinDelay = 10 * time.Minute usernameRegex = regexp.MustCompile("^[a-zA-Z0-9-_.~]+$") tempPath string + fnReloadRules FnReloadRules + fnRemoveRule FnRemoveRule + fnHandleRuleForProviderEvent FnHandleRuleForProviderEvent ) func initSQLTables() { @@ -214,6 +217,22 @@ func initSQLTables() { sqlTableSchemaVersion = "schema_version" } +// FnReloadRules defined the callback to reload event rules +type FnReloadRules func() + +// FnRemoveRule defines the callback to remove an event rule +type FnRemoveRule func(name string) + +// FnHandleRuleForProviderEvent define the callback to handle event rules for provider events +type FnHandleRuleForProviderEvent func(operation, executor, ip, objectType, objectName string, object plugin.Renderer) + +// SetEventRulesCallbacks sets the event rules callbacks +func SetEventRulesCallbacks(reload FnReloadRules, remove FnRemoveRule, handle FnHandleRuleForProviderEvent) { + fnReloadRules = reload + fnRemoveRule = remove + fnHandleRuleForProviderEvent = handle +} + type schemaVersion struct { Version int } @@ -487,31 +506,36 @@ func (c *Config) requireCustomTLSForMySQL() bool { func (c *Config) doBackup() error { now := time.Now().UTC() outputFile := filepath.Join(c.BackupsPath, fmt.Sprintf("backup_%s_%d.json", now.Weekday(), now.Hour())) - eventManagerLog(logger.LevelDebug, "starting backup to file %q", outputFile) + providerLog(logger.LevelDebug, "starting backup to file %q", outputFile) err := os.MkdirAll(filepath.Dir(outputFile), 0700) if err != nil { - eventManagerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err) + providerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err) return fmt.Errorf("unable to create backup dir: %w", err) } backup, err := DumpData() if err != nil { - eventManagerLog(logger.LevelError, "unable to execute backup: %v", err) + providerLog(logger.LevelError, "unable to execute backup: %v", err) return fmt.Errorf("unable to dump backup data: %w", err) } dump, err := json.Marshal(backup) if err != nil { - eventManagerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err) + providerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err) return fmt.Errorf("unable to marshal backup data as JSON: %w", err) } err = os.WriteFile(outputFile, dump, 0600) if err != nil { - eventManagerLog(logger.LevelError, "unable to save backup: %v", err) + providerLog(logger.LevelError, "unable to save backup: %v", err) return fmt.Errorf("unable to save backup: %w", err) } - eventManagerLog(logger.LevelDebug, "auto backup saved to %q", outputFile) + providerLog(logger.LevelDebug, "backup saved to %q", outputFile) return nil } +// ExecuteBackup executes a backup +func ExecuteBackup() error { + return config.doBackup() +} + // ConvertName converts the given name based on the configured rules func ConvertName(name string) string { return config.convertName(name) @@ -1568,7 +1592,9 @@ func AddEventAction(action *BaseEventAction, executor, ipAddress string) error { func UpdateEventAction(action *BaseEventAction, executor, ipAddress string) error { err := provider.updateEventAction(action) if err == nil { - EventManager.loadRules() + if fnReloadRules != nil { + fnReloadRules() + } executeAction(operationUpdate, executor, ipAddress, actionObjectEventAction, action.Name, action) } return err @@ -1597,6 +1623,11 @@ func GetEventRules(limit, offset int, order string) ([]EventRule, error) { return provider.getEventRules(limit, offset, order) } +// GetRecentlyUpdatedRules returns the event rules updated after the specified time +func GetRecentlyUpdatedRules(after int64) ([]EventRule, error) { + return provider.getRecentlyUpdatedRules(after) +} + // EventRuleExists returns the event rule with the given name if it exists func EventRuleExists(name string) (EventRule, error) { name = config.convertName(name) @@ -1608,7 +1639,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error { rule.Name = config.convertName(rule.Name) err := provider.addEventRule(rule) if err == nil { - EventManager.loadRules() + if fnReloadRules != nil { + fnReloadRules() + } executeAction(operationAdd, executor, ipAddress, actionObjectEventRule, rule.Name, rule) } return err @@ -1618,7 +1651,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error { func UpdateEventRule(rule *EventRule, executor, ipAddress string) error { err := provider.updateEventRule(rule) if err == nil { - EventManager.loadRules() + if fnReloadRules != nil { + fnReloadRules() + } executeAction(operationUpdate, executor, ipAddress, actionObjectEventRule, rule.Name, rule) } return err @@ -1633,12 +1668,39 @@ func DeleteEventRule(name string, executor, ipAddress string) error { } err = provider.deleteEventRule(rule, config.IsShared == 1) if err == nil { - EventManager.RemoveRule(rule.Name) + if fnRemoveRule != nil { + fnRemoveRule(rule.Name) + } executeAction(operationDelete, executor, ipAddress, actionObjectEventRule, rule.Name, &rule) } return err } +// RemoveEventRule delets an existing event rule without marking it as deleted +func RemoveEventRule(rule EventRule) error { + return provider.deleteEventRule(rule, false) +} + +// GetTaskByName returns the task with the specified name +func GetTaskByName(name string) (Task, error) { + return provider.getTaskByName(name) +} + +// AddTask add a task with the specified name +func AddTask(name string) error { + return provider.addTask(name) +} + +// UpdateTask updates the task with the specified name and version +func UpdateTask(name string, version int64) error { + return provider.updateTask(name, version) +} + +// UpdateTaskTimestamp updates the timestamp for the task with the specified name +func UpdateTaskTimestamp(name string) error { + return provider.updateTaskTimestamp(name) +} + // HasAdmin returns true if the first admin has been created // and so SFTPGo is ready to be used func HasAdmin() bool { @@ -1971,6 +2033,16 @@ func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtua return provider.getFolders(limit, offset, order, minimal) } +// DumpUsers returns all users, including confidential data +func DumpUsers() ([]User, error) { + return provider.dumpUsers() +} + +// DumpFolders returns all folders, including confidential data +func DumpFolders() ([]vfs.BaseVirtualFolder, error) { + return provider.dumpFolders() +} + // DumpData returns all users, groups, folders, admins, api keys, shares, actions, rules func DumpData() (BackupData, error) { var data BackupData diff --git a/internal/dataprovider/eventrule.go b/internal/dataprovider/eventrule.go index 191b89df..e210e167 100644 --- a/internal/dataprovider/eventrule.go +++ b/internal/dataprovider/eventrule.go @@ -15,16 +15,10 @@ package dataprovider import ( - "bytes" - "context" "crypto/tls" "encoding/json" "fmt" - "io" "net/http" - "net/url" - "os" - "os/exec" "path" "path/filepath" "strings" @@ -34,9 +28,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" - "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" - "github.com/drakkan/sftpgo/v2/internal/vfs" ) // Supported event actions @@ -204,25 +196,8 @@ func (c *EventActionHTTPConfig) validate(additionalData string) error { return nil } -func (c *EventActionHTTPConfig) getEndpoint(replacer *strings.Replacer) (string, error) { - if len(c.QueryParameters) > 0 { - u, err := url.Parse(c.Endpoint) - if err != nil { - return "", fmt.Errorf("invalid endpoint: %w", err) - } - q := u.Query() - - for _, keyVal := range c.QueryParameters { - q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) - } - - u.RawQuery = q.Encode() - return u.String(), nil - } - return c.Endpoint, nil -} - -func (c *EventActionHTTPConfig) getHTTPClient() *http.Client { +// GetHTTPClient returns an HTTP client based on the config +func (c *EventActionHTTPConfig) GetHTTPClient() *http.Client { client := &http.Client{ Timeout: time.Duration(c.Timeout) * time.Second, } @@ -241,63 +216,6 @@ func (c *EventActionHTTPConfig) getHTTPClient() *http.Client { return client } -func (c *EventActionHTTPConfig) execute(params EventParams) error { - if !c.Password.IsEmpty() { - if err := c.Password.TryDecrypt(); err != nil { - return fmt.Errorf("unable to decrypt password: %w", err) - } - } - addObjectData := false - if params.Object != nil { - if !addObjectData { - if strings.Contains(c.Body, "{{ObjectData}}") { - addObjectData = true - } - } - } - - replacements := params.getStringReplacements(addObjectData) - replacer := strings.NewReplacer(replacements...) - endpoint, err := c.getEndpoint(replacer) - if err != nil { - return err - } - - var body io.Reader - if c.Body != "" && c.Method != http.MethodGet { - body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer)) - } - req, err := http.NewRequest(c.Method, endpoint, body) - if err != nil { - return err - } - if c.Username != "" { - req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData()) - } - for _, keyVal := range c.Headers { - req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) - } - client := c.getHTTPClient() - defer client.CloseIdleConnections() - - startTime := time.Now() - resp, err := client.Do(req) - if err != nil { - eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v", - endpoint, time.Since(startTime), err) - return err - } - defer resp.Body.Close() - - eventManagerLog(logger.LevelDebug, "http notification sent, endopoint: %s, elapsed: %s, status code: %d", - endpoint, time.Since(startTime), resp.StatusCode) - if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - return nil -} - // EventActionCommandConfig defines the configuration for a command event target type EventActionCommandConfig struct { Cmd string `json:"cmd"` @@ -323,43 +241,6 @@ func (c *EventActionCommandConfig) validate() error { return nil } -func (c *EventActionCommandConfig) getEnvVars(params EventParams) []string { - envVars := make([]string, 0, len(c.EnvVars)) - addObjectData := false - if params.Object != nil { - for _, k := range c.EnvVars { - if strings.Contains(k.Value, "{{ObjectData}}") { - addObjectData = true - break - } - } - } - replacements := params.getStringReplacements(addObjectData) - replacer := strings.NewReplacer(replacements...) - for _, keyVal := range c.EnvVars { - envVars = append(envVars, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))) - } - - return envVars -} - -func (c *EventActionCommandConfig) execute(params EventParams) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second) - defer cancel() - - cmd := exec.CommandContext(ctx, c.Cmd) - cmd.Env = append(cmd.Env, os.Environ()...) - cmd.Env = append(cmd.Env, c.getEnvVars(params)...) - - startTime := time.Now() - err := cmd.Run() - - eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v", - c.Cmd, time.Since(startTime), err) - - return err -} - // EventActionEmailConfig defines the configuration options for SMTP event actions type EventActionEmailConfig struct { Recipients []string `json:"recipients"` @@ -391,24 +272,6 @@ func (o *EventActionEmailConfig) validate() error { return nil } -func (o *EventActionEmailConfig) execute(params EventParams) error { - addObjectData := false - if params.Object != nil { - if strings.Contains(o.Body, "{{ObjectData}}") { - addObjectData = true - } - } - replacements := params.getStringReplacements(addObjectData) - replacer := strings.NewReplacer(replacements...) - body := replaceWithReplacer(o.Body, replacer) - subject := replaceWithReplacer(o.Subject, replacer) - startTime := time.Now() - err := smtp.SendEmail(o.Recipients, subject, body, smtp.EmailContentTypeTextPlain) - eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v", - time.Since(startTime), err) - return err -} - // BaseEventActionOptions defines the supported configuration options for a base event actions type BaseEventActionOptions struct { HTTPConfig EventActionHTTPConfig `json:"http_config"` @@ -560,130 +423,6 @@ func (a *BaseEventAction) validate() error { return a.Options.validate(a.Type, a.Name) } -func (a *BaseEventAction) doUsersQuotaReset(conditions ConditionOptions) error { - users, err := provider.dumpUsers() - if err != nil { - return fmt.Errorf("unable to get users: %w", err) - } - var failedResets []string - for _, user := range users { - if !checkConditionPatterns(user.Username, conditions.Names) { - eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match", - user.Username) - continue - } - if !QuotaScans.AddUserQuotaScan(user.Username) { - eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username) - failedResets = append(failedResets, user.Username) - continue - } - numFiles, size, err := user.ScanQuota() - QuotaScans.RemoveUserQuotaScan(user.Username) - if err != nil { - eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, err) - failedResets = append(failedResets, user.Username) - continue - } - err = UpdateUserQuota(&user, numFiles, size, true) - if err != nil { - eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err) - failedResets = append(failedResets, user.Username) - continue - } - } - if len(failedResets) > 0 { - return fmt.Errorf("quota reset failed for users: %+v", failedResets) - } - return nil -} - -func (a *BaseEventAction) doFoldersQuotaReset(conditions ConditionOptions) error { - folders, err := provider.dumpFolders() - if err != nil { - return fmt.Errorf("unable to get folders: %w", err) - } - var failedResets []string - for _, folder := range folders { - if !checkConditionPatterns(folder.Name, conditions.Names) { - eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match", - folder.Name) - continue - } - if !QuotaScans.AddVFolderQuotaScan(folder.Name) { - eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name) - failedResets = append(failedResets, folder.Name) - continue - } - f := vfs.VirtualFolder{ - BaseVirtualFolder: folder, - VirtualPath: "/", - } - numFiles, size, err := f.ScanQuota() - QuotaScans.RemoveVFolderQuotaScan(folder.Name) - if err != nil { - eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, err) - failedResets = append(failedResets, folder.Name) - continue - } - err = UpdateVirtualFolderQuota(&folder, numFiles, size, true) - if err != nil { - eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err) - failedResets = append(failedResets, folder.Name) - continue - } - } - if len(failedResets) > 0 { - return fmt.Errorf("quota reset failed for folders: %+v", failedResets) - } - return nil -} - -func (a *BaseEventAction) doTransferQuotaReset(conditions ConditionOptions) error { - users, err := provider.dumpUsers() - if err != nil { - return fmt.Errorf("unable to get users: %w", err) - } - var failedResets []string - for _, user := range users { - if !checkConditionPatterns(user.Username, conditions.Names) { - eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match", - user.Username) - continue - } - err = UpdateUserTransferQuota(&user, 0, 0, true) - if err != nil { - eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err) - failedResets = append(failedResets, user.Username) - continue - } - } - if len(failedResets) > 0 { - return fmt.Errorf("transfer quota reset failed for users: %+v", failedResets) - } - return nil -} - -func (a *BaseEventAction) execute(params EventParams, conditions ConditionOptions) error { - switch a.Type { - case ActionTypeHTTP: - return a.Options.HTTPConfig.execute(params) - case ActionTypeCommand: - return a.Options.CmdConfig.execute(params) - case ActionTypeEmail: - return a.Options.EmailConfig.execute(params) - case ActionTypeBackup: - return config.doBackup() - case ActionTypeUserQuotaReset: - return a.doUsersQuotaReset(conditions) - case ActionTypeFolderQuotaReset: - return a.doFoldersQuotaReset(conditions) - case ActionTypeTransferQuotaReset: - return a.doTransferQuotaReset(conditions) - default: - return fmt.Errorf("unsupported action type: %d", a.Type) - } -} - // EventActionOptions defines the supported configuration options for an event action type EventActionOptions struct { IsFailureAction bool `json:"is_failure_action"` @@ -731,18 +470,6 @@ type ConditionPattern struct { InverseMatch bool `json:"inverse_match,omitempty"` } -func (p *ConditionPattern) match(name string) bool { - matched, err := path.Match(p.Pattern, name) - if err != nil { - eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err) - return false - } - if p.InverseMatch { - return !matched - } - return matched -} - func (p *ConditionPattern) validate() error { if p.Pattern == "" { return util.NewValidationError("empty condition pattern not allowed") @@ -826,12 +553,13 @@ type Schedule struct { Month string `json:"month"` } -func (s *Schedule) getCronSpec() string { +// GetCronSpec returns the cron compatible schedule string +func (s *Schedule) GetCronSpec() string { return fmt.Sprintf("0 %s %s %s %s", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek) } func (s *Schedule) validate() error { - _, err := cron.ParseStandard(s.getCronSpec()) + _, err := cron.ParseStandard(s.GetCronSpec()) if err != nil { return util.NewValidationError(fmt.Sprintf("invalid schedule, hour: %q, day of month: %q, month: %q, day of week: %q", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek)) @@ -871,51 +599,6 @@ func (c *EventConditions) getACopy() EventConditions { } } -// ProviderEventMatch returns true if the specified provider event match -func (c *EventConditions) ProviderEventMatch(params EventParams) bool { - if !util.Contains(c.ProviderEvents, params.Event) { - return false - } - if !checkConditionPatterns(params.Name, c.Options.Names) { - return false - } - if len(c.Options.ProviderObjects) > 0 && !util.Contains(c.Options.ProviderObjects, params.ObjectType) { - return false - } - return true -} - -// FsEventMatch returns true if the specified filesystem event match -func (c *EventConditions) FsEventMatch(params EventParams) bool { - if !util.Contains(c.FsEvents, params.Event) { - return false - } - if !checkConditionPatterns(params.Name, c.Options.Names) { - return false - } - if !checkConditionPatterns(params.VirtualPath, c.Options.FsPaths) { - if !checkConditionPatterns(params.ObjectName, c.Options.FsPaths) { - return false - } - } - if len(c.Options.Protocols) > 0 && !util.Contains(c.Options.Protocols, params.Protocol) { - return false - } - if params.Event == "upload" || params.Event == "download" { - if c.Options.MinFileSize > 0 { - if params.FileSize < c.Options.MinFileSize { - return false - } - } - if c.Options.MaxFileSize > 0 { - if params.FileSize > c.Options.MaxFileSize { - return false - } - } - } - return true -} - func (c *EventConditions) validate(trigger int) error { switch trigger { case EventTriggerFsEvent: @@ -1015,7 +698,9 @@ func (r *EventRule) getACopy() EventRule { } } -func (r *EventRule) guardFromConcurrentExecution() bool { +// GuardFromConcurrentExecution returns true if the rule cannot be executed concurrently +// from multiple instances +func (r *EventRule) GuardFromConcurrentExecution() bool { if config.IsShared == 0 { return false } @@ -1102,6 +787,28 @@ func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) { return json.Marshal(r) } +func cloneKeyValues(keyVals []KeyValue) []KeyValue { + res := make([]KeyValue, 0, len(keyVals)) + for _, kv := range keyVals { + res = append(res, KeyValue{ + Key: kv.Key, + Value: kv.Value, + }) + } + return res +} + +func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern { + res := make([]ConditionPattern, 0, len(patterns)) + for _, p := range patterns { + res = append(res, ConditionPattern{ + Pattern: p.Pattern, + InverseMatch: p.InverseMatch, + }) + } + return res +} + // Task stores the state for a scheduled task type Task struct { Name string `json:"name"` diff --git a/internal/dataprovider/eventruleutil.go b/internal/dataprovider/eventruleutil.go deleted file mode 100644 index b0e1cf9e..00000000 --- a/internal/dataprovider/eventruleutil.go +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright (C) 2019-2022 Nicola Murino -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published -// by the Free Software Foundation, version 3. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package dataprovider - -import ( - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/robfig/cron/v3" - - "github.com/drakkan/sftpgo/v2/internal/logger" - "github.com/drakkan/sftpgo/v2/internal/plugin" - "github.com/drakkan/sftpgo/v2/internal/util" -) - -var ( - // EventManager handle the supported event rules actions - EventManager EventRulesContainer -) - -func init() { - EventManager = EventRulesContainer{ - schedulesMapping: make(map[string][]cron.EntryID), - } -} - -// EventRulesContainer stores event rules by trigger -type EventRulesContainer struct { - sync.RWMutex - FsEvents []EventRule - ProviderEvents []EventRule - Schedules []EventRule - schedulesMapping map[string][]cron.EntryID - lastLoad int64 -} - -func (r *EventRulesContainer) getLastLoadTime() int64 { - return atomic.LoadInt64(&r.lastLoad) -} - -func (r *EventRulesContainer) setLastLoadTime(modTime int64) { - atomic.StoreInt64(&r.lastLoad, modTime) -} - -// RemoveRule deletes the rule with the specified name -func (r *EventRulesContainer) RemoveRule(name string) { - r.Lock() - defer r.Unlock() - - r.removeRuleInternal(name) - eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d", - len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) -} - -func (r *EventRulesContainer) removeRuleInternal(name string) { - for idx := range r.FsEvents { - if r.FsEvents[idx].Name == name { - lastIdx := len(r.FsEvents) - 1 - r.FsEvents[idx] = r.FsEvents[lastIdx] - r.FsEvents = r.FsEvents[:lastIdx] - eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name) - return - } - } - for idx := range r.ProviderEvents { - if r.ProviderEvents[idx].Name == name { - lastIdx := len(r.ProviderEvents) - 1 - r.ProviderEvents[idx] = r.ProviderEvents[lastIdx] - r.ProviderEvents = r.ProviderEvents[:lastIdx] - eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name) - return - } - } - for idx := range r.Schedules { - if r.Schedules[idx].Name == name { - if schedules, ok := r.schedulesMapping[name]; ok { - for _, entryID := range schedules { - eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name) - scheduler.Remove(entryID) - } - delete(r.schedulesMapping, name) - } - - lastIdx := len(r.Schedules) - 1 - r.Schedules[idx] = r.Schedules[lastIdx] - r.Schedules = r.Schedules[:lastIdx] - eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name) - return - } - } -} - -func (r *EventRulesContainer) addUpdateRuleInternal(rule EventRule) { - r.removeRuleInternal(rule.Name) - if rule.DeletedAt > 0 { - deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt) - if deletedAt.Add(30 * time.Minute).Before(time.Now()) { - eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt) - go provider.deleteEventRule(rule, false) //nolint:errcheck - } - return - } - switch rule.Trigger { - case EventTriggerFsEvent: - r.FsEvents = append(r.FsEvents, rule) - eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name) - case EventTriggerProviderEvent: - r.ProviderEvents = append(r.ProviderEvents, rule) - eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name) - case EventTriggerSchedule: - r.Schedules = append(r.Schedules, rule) - eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name) - for _, schedule := range rule.Conditions.Schedules { - cronSpec := schedule.getCronSpec() - job := &cronJob{ - ruleName: ConvertName(rule.Name), - } - entryID, err := scheduler.AddJob(cronSpec, job) - if err != nil { - eventManagerLog(logger.LevelError, "unable to add scheduled rule %q: %v", rule.Name, err) - } else { - r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID) - eventManagerLog(logger.LevelDebug, "scheduled rule %q added, id: %d, active scheduling rules: %d", - rule.Name, entryID, len(r.schedulesMapping)) - } - } - default: - eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger) - } -} - -func (r *EventRulesContainer) loadRules() { - eventManagerLog(logger.LevelDebug, "loading updated rules") - modTime := util.GetTimeAsMsSinceEpoch(time.Now()) - rules, err := provider.getRecentlyUpdatedRules(r.getLastLoadTime()) - if err != nil { - eventManagerLog(logger.LevelError, "unable to load event rules: %v", err) - return - } - eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules)) - - if len(rules) > 0 { - r.Lock() - defer r.Unlock() - - for _, rule := range rules { - r.addUpdateRuleInternal(rule) - } - } - eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d", - len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) - - r.setLastLoadTime(modTime) -} - -// HasFsRules returns true if there are any rules for filesystem event triggers -func (r *EventRulesContainer) HasFsRules() bool { - r.RLock() - defer r.RUnlock() - - return len(r.FsEvents) > 0 -} - -func (r *EventRulesContainer) hasProviderEvents() bool { - r.RLock() - defer r.RUnlock() - - return len(r.ProviderEvents) > 0 -} - -// HandleFsEvent executes the rules actions defined for the specified event -func (r *EventRulesContainer) HandleFsEvent(params EventParams) error { - r.RLock() - - var rulesWithSyncActions, rulesAsync []EventRule - for _, rule := range r.FsEvents { - if rule.Conditions.FsEventMatch(params) { - hasSyncActions := false - for _, action := range rule.Actions { - if action.Options.ExecuteSync { - hasSyncActions = true - break - } - } - if hasSyncActions { - rulesWithSyncActions = append(rulesWithSyncActions, rule) - } else { - rulesAsync = append(rulesAsync, rule) - } - } - } - - r.RUnlock() - - if len(rulesAsync) > 0 { - go executeAsyncActions(rulesAsync, params) - } - - if len(rulesWithSyncActions) > 0 { - return executeSyncActions(rulesWithSyncActions, params) - } - return nil -} - -func (r *EventRulesContainer) handleProviderEvent(params EventParams) { - r.RLock() - defer r.RUnlock() - - var rules []EventRule - for _, rule := range r.ProviderEvents { - if rule.Conditions.ProviderEventMatch(params) { - rules = append(rules, rule) - } - } - - go executeAsyncActions(rules, params) -} - -// EventParams defines the supported event parameters -type EventParams struct { - Name string - Event string - Status int - VirtualPath string - FsPath string - VirtualTargetPath string - FsTargetPath string - ObjectName string - ObjectType string - FileSize int64 - Protocol string - IP string - Timestamp int64 - Object plugin.Renderer -} - -func (p *EventParams) getStringReplacements(addObjectData bool) []string { - replacements := []string{ - "{{Name}}", p.Name, - "{{Event}}", p.Event, - "{{Status}}", fmt.Sprintf("%d", p.Status), - "{{VirtualPath}}", p.VirtualPath, - "{{FsPath}}", p.FsPath, - "{{VirtualTargetPath}}", p.VirtualTargetPath, - "{{FsTargetPath}}", p.FsTargetPath, - "{{ObjectName}}", p.ObjectName, - "{{ObjectType}}", p.ObjectType, - "{{FileSize}}", fmt.Sprintf("%d", p.FileSize), - "{{Protocol}}", p.Protocol, - "{{IP}}", p.IP, - "{{Timestamp}}", fmt.Sprintf("%d", p.Timestamp), - } - if addObjectData { - data, err := p.Object.RenderAsJSON(p.Event != operationDelete) - if err == nil { - replacements = append(replacements, "{{ObjectData}}", string(data)) - } - } - return replacements -} - -func replaceWithReplacer(input string, replacer *strings.Replacer) string { - if !strings.Contains(input, "{{") { - return input - } - return replacer.Replace(input) -} - -// checkConditionPatterns returns false if patterns are defined and no match is found -func checkConditionPatterns(name string, patterns []ConditionPattern) bool { - if len(patterns) == 0 { - return true - } - for _, p := range patterns { - if p.match(name) { - return true - } - } - - return false -} - -func executeSyncActions(rules []EventRule, params EventParams) error { - var errRes error - - for _, rule := range rules { - var failedActions []string - for _, action := range rule.Actions { - if !action.Options.IsFailureAction && action.Options.ExecuteSync { - startTime := time.Now() - if err := action.execute(params, rule.Conditions.Options); err != nil { - eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v", - action.Name, rule.Name, time.Since(startTime), err) - failedActions = append(failedActions, action.Name) - // we return the last error, it is ok for now - errRes = err - if action.Options.StopOnFailure { - break - } - } else { - eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s", - action.Name, rule.Name, time.Since(startTime)) - } - } - } - // execute async actions if any, including failure actions - go executeRuleAsyncActions(rule, params, failedActions) - } - - return errRes -} - -func executeAsyncActions(rules []EventRule, params EventParams) { - for _, rule := range rules { - executeRuleAsyncActions(rule, params, nil) - } -} - -func executeRuleAsyncActions(rule EventRule, params EventParams, failedActions []string) { - for _, action := range rule.Actions { - if !action.Options.IsFailureAction && !action.Options.ExecuteSync { - startTime := time.Now() - if err := action.execute(params, rule.Conditions.Options); err != nil { - eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v", - action.Name, rule.Name, time.Since(startTime), err) - failedActions = append(failedActions, action.Name) - if action.Options.StopOnFailure { - break - } - } else { - eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s", - action.Name, rule.Name, time.Since(startTime)) - } - } - if len(failedActions) > 0 { - // execute failure actions - for _, action := range rule.Actions { - if action.Options.IsFailureAction { - startTime := time.Now() - if err := action.execute(params, rule.Conditions.Options); err != nil { - eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v", - action.Name, rule.Name, time.Since(startTime), err) - if action.Options.StopOnFailure { - break - } - } else { - eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s", - action.Name, rule.Name, time.Since(startTime)) - } - } - } - } - } -} - -type cronJob struct { - ruleName string -} - -func (j *cronJob) getTask(rule EventRule) (Task, error) { - if rule.guardFromConcurrentExecution() { - task, err := provider.getTaskByName(rule.Name) - if _, ok := err.(*util.RecordNotFoundError); ok { - eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name) - task = Task{ - Name: rule.Name, - UpdateAt: 0, - Version: 0, - } - err = provider.addTask(rule.Name) - if err != nil { - eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err) - return task, err - } - } else { - eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err) - } - return task, err - } - - return Task{}, nil -} - -func (j *cronJob) Run() { - eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName) - rule, err := provider.eventRuleExists(j.ruleName) - if err != nil { - eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName) - return - } - task, err := j.getTask(rule) - if err != nil { - return - } - if task.Name != "" { - updateInterval := 5 * time.Minute - updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt) - if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) { - eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt) - return - } - err = provider.updateTask(rule.Name, task.Version) - if err != nil { - eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v", - rule.Name, err) - return - } - ticker := time.NewTicker(updateInterval) - done := make(chan bool) - - go func(taskName string) { - eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName) - for { - select { - case <-done: - eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName) - return - case <-ticker.C: - err := provider.updateTaskTimestamp(taskName) - eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err) - } - } - }(task.Name) - - executeRuleAsyncActions(rule, EventParams{}, nil) - - done <- true - ticker.Stop() - } else { - executeRuleAsyncActions(rule, EventParams{}, nil) - } - eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName) -} - -func cloneKeyValues(keyVals []KeyValue) []KeyValue { - res := make([]KeyValue, 0, len(keyVals)) - for _, kv := range keyVals { - res = append(res, KeyValue{ - Key: kv.Key, - Value: kv.Value, - }) - } - return res -} - -func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern { - res := make([]ConditionPattern, 0, len(patterns)) - for _, p := range patterns { - res = append(res, ConditionPattern{ - Pattern: p.Pattern, - InverseMatch: p.InverseMatch, - }) - } - return res -} - -func eventManagerLog(level logger.LogLevel, format string, v ...any) { - logger.Log(level, "eventmanager", "", format, v...) -} diff --git a/internal/dataprovider/quota.go b/internal/dataprovider/quota.go deleted file mode 100644 index cdcba099..00000000 --- a/internal/dataprovider/quota.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (C) 2019-2022 Nicola Murino -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published -// by the Free Software Foundation, version 3. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package dataprovider - -import ( - "sync" - "time" - - "github.com/drakkan/sftpgo/v2/internal/util" -) - -var ( - // QuotaScans is the list of active quota scans - QuotaScans ActiveScans -) - -// ActiveQuotaScan defines an active quota scan for a user home dir -type ActiveQuotaScan struct { - // Username to which the quota scan refers - Username string `json:"username"` - // quota scan start time as unix timestamp in milliseconds - StartTime int64 `json:"start_time"` -} - -// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder -type ActiveVirtualFolderQuotaScan struct { - // folder name to which the quota scan refers - Name string `json:"name"` - // quota scan start time as unix timestamp in milliseconds - StartTime int64 `json:"start_time"` -} - -// ActiveScans holds the active quota scans -type ActiveScans struct { - sync.RWMutex - UserScans []ActiveQuotaScan - FolderScans []ActiveVirtualFolderQuotaScan -} - -// GetUsersQuotaScans returns the active quota scans for users home directories -func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan { - s.RLock() - defer s.RUnlock() - - scans := make([]ActiveQuotaScan, len(s.UserScans)) - copy(scans, s.UserScans) - return scans -} - -// AddUserQuotaScan adds a user to the ones with active quota scans. -// Returns false if the user has a quota scan already running -func (s *ActiveScans) AddUserQuotaScan(username string) bool { - s.Lock() - defer s.Unlock() - - for _, scan := range s.UserScans { - if scan.Username == username { - return false - } - } - s.UserScans = append(s.UserScans, ActiveQuotaScan{ - Username: username, - StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), - }) - return true -} - -// RemoveUserQuotaScan removes a user from the ones with active quota scans. -// Returns false if the user has no active quota scans -func (s *ActiveScans) RemoveUserQuotaScan(username string) bool { - s.Lock() - defer s.Unlock() - - for idx, scan := range s.UserScans { - if scan.Username == username { - lastIdx := len(s.UserScans) - 1 - s.UserScans[idx] = s.UserScans[lastIdx] - s.UserScans = s.UserScans[:lastIdx] - return true - } - } - - return false -} - -// GetVFoldersQuotaScans returns the active quota scans for virtual folders -func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan { - s.RLock() - defer s.RUnlock() - scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans)) - copy(scans, s.FolderScans) - return scans -} - -// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans. -// Returns false if the folder has a quota scan already running -func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool { - s.Lock() - defer s.Unlock() - - for _, scan := range s.FolderScans { - if scan.Name == folderName { - return false - } - } - s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{ - Name: folderName, - StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), - }) - return true -} - -// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans. -// Returns false if the folder has no active quota scans -func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool { - s.Lock() - defer s.Unlock() - - for idx, scan := range s.FolderScans { - if scan.Name == folderName { - lastIdx := len(s.FolderScans) - 1 - s.FolderScans[idx] = s.FolderScans[lastIdx] - s.FolderScans = s.FolderScans[:lastIdx] - return true - } - } - - return false -} diff --git a/internal/dataprovider/scheduler.go b/internal/dataprovider/scheduler.go index 7cb7ad03..2b9d7a94 100644 --- a/internal/dataprovider/scheduler.go +++ b/internal/dataprovider/scheduler.go @@ -54,7 +54,9 @@ func startScheduler() error { if err != nil { return err } - EventManager.loadRules() + if fnReloadRules != nil { + fnReloadRules() + } scheduler.Start() return nil } @@ -77,7 +79,7 @@ func checkDataprovider() { } func checkCacheUpdates() { - providerLog(logger.LevelDebug, "start caches check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) + providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) checkTime := util.GetTimeAsMsSinceEpoch(time.Now()) users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate) if err != nil { @@ -101,8 +103,7 @@ func checkCacheUpdates() { } lastUserCacheUpdate = checkTime - EventManager.loadRules() - providerLog(logger.LevelDebug, "end caches check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) + providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) } func setLastUserUpdate() { diff --git a/internal/httpd/api_maintenance.go b/internal/httpd/api_maintenance.go index bbf86804..61b5c938 100644 --- a/internal/httpd/api_maintenance.go +++ b/internal/httpd/api_maintenance.go @@ -27,6 +27,7 @@ import ( "github.com/go-chi/render" + "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" @@ -265,7 +266,7 @@ func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, sca return fmt.Errorf("unable to restore folder %#v: %w", folder.Name, err) } if scanQuota >= 1 { - if dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) { + if common.QuotaScans.AddVFolderQuotaScan(folder.Name) { logger.Debug(logSender, "", "starting quota scan for restored folder: %#v", folder.Name) go doFolderQuotaScan(folder) //nolint:errcheck } @@ -453,7 +454,7 @@ func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota i return fmt.Errorf("unable to restore user %#v: %w", user.Username, err) } if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) { - if dataprovider.QuotaScans.AddUserQuotaScan(user.Username) { + if common.QuotaScans.AddUserQuotaScan(user.Username) { logger.Debug(logSender, "", "starting quota scan for restored user: %#v", user.Username) go doUserQuotaScan(user) //nolint:errcheck } diff --git a/internal/httpd/api_quota.go b/internal/httpd/api_quota.go index c675a97c..a3e63baa 100644 --- a/internal/httpd/api_quota.go +++ b/internal/httpd/api_quota.go @@ -21,6 +21,7 @@ import ( "github.com/go-chi/render" + "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/vfs" @@ -43,12 +44,12 @@ type transferQuotaUsage struct { func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - render.JSON(w, r, dataprovider.QuotaScans.GetUsersQuotaScans()) + render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans()) } func getFoldersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - render.JSON(w, r, dataprovider.QuotaScans.GetVFoldersQuotaScans()) + render.JSON(w, r, common.QuotaScans.GetVFoldersQuotaScans()) } func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) { @@ -141,11 +142,11 @@ func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username str "", http.StatusBadRequest) return } - if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) { + if !common.QuotaScans.AddUserQuotaScan(user.Username) { sendAPIResponse(w, r, err, "A quota scan is in progress for this user", http.StatusConflict) return } - defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username) + defer common.QuotaScans.RemoveUserQuotaScan(user.Username) err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -170,11 +171,11 @@ func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name strin sendAPIResponse(w, r, err, "", getRespStatus(err)) return } - if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) { + if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { sendAPIResponse(w, r, err, "A quota scan is in progress for this folder", http.StatusConflict) return } - defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name) + defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) err = dataprovider.UpdateVirtualFolderQuota(&folder, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -193,7 +194,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin sendAPIResponse(w, r, err, "", getRespStatus(err)) return } - if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) { + if !common.QuotaScans.AddUserQuotaScan(user.Username) { sendAPIResponse(w, r, nil, fmt.Sprintf("Another scan is already in progress for user %#v", username), http.StatusConflict) return @@ -212,7 +213,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string) sendAPIResponse(w, r, err, "", getRespStatus(err)) return } - if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) { + if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { sendAPIResponse(w, r, err, fmt.Sprintf("Another scan is already in progress for folder %#v", name), http.StatusConflict) return @@ -222,7 +223,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string) } func doUserQuotaScan(user dataprovider.User) error { - defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username) + defer common.QuotaScans.RemoveUserQuotaScan(user.Username) numFiles, size, err := user.ScanQuota() if err != nil { logger.Warn(logSender, "", "error scanning user quota %#v: %v", user.Username, err) @@ -234,7 +235,7 @@ func doUserQuotaScan(user dataprovider.User) error { } func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error { - defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name) + defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) f := vfs.VirtualFolder{ BaseVirtualFolder: folder, VirtualPath: "/", diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index 4660196a..54ff1df1 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -1951,6 +1951,8 @@ func TestHTTPUserAuthEmptyPassword(t *testing.T) { c.CloseIdleConnections() assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, "") if assert.Error(t, err) { @@ -1986,6 +1988,8 @@ func TestHTTPAnonymousUser(t *testing.T) { c.CloseIdleConnections() assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) if assert.Error(t, err) { @@ -9347,12 +9351,12 @@ func TestUpdateUserQuotaUsageMock(t *testing.T) { setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) - assert.True(t, dataprovider.QuotaScans.AddUserQuotaScan(user.Username)) + assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username)) req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON)) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusConflict, rr) - assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)) + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username)) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) setBearerForReq(req, token) rr = executeRequest(req) @@ -9624,12 +9628,12 @@ func TestStartQuotaScanMock(t *testing.T) { assert.NoError(t, err) } // simulate a duplicate quota scan - dataprovider.QuotaScans.AddUserQuotaScan(user.Username) + common.QuotaScans.AddUserQuotaScan(user.Username) req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusConflict, rr) - assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)) + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username)) req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) setBearerForReq(req, token) @@ -9743,13 +9747,13 @@ func TestUpdateFolderQuotaUsageMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) - assert.True(t, dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)) + assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName)) req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), bytes.NewBuffer(folderAsJSON)) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusConflict, rr) - assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)) + assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName)) req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) setBearerForReq(req, token) @@ -9778,12 +9782,12 @@ func TestStartFolderQuotaScanMock(t *testing.T) { assert.NoError(t, err) } // simulate a duplicate quota scan - dataprovider.QuotaScans.AddVFolderQuotaScan(folderName) + common.QuotaScans.AddVFolderQuotaScan(folderName) req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil) setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusConflict, rr) - assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)) + assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName)) // and now a real quota scan _, err = os.Stat(mappedPath) if err != nil && errors.Is(err, fs.ErrNotExist) { @@ -20392,7 +20396,7 @@ func startOIDCMockServer() { func waitForUsersQuotaScan(t *testing.T, token string) { for { - var scans []dataprovider.ActiveQuotaScan + var scans []common.ActiveQuotaScan req, _ := http.NewRequest(http.MethodGet, quotaScanPath, nil) setBearerForReq(req, token) rr := executeRequest(req) @@ -20410,7 +20414,7 @@ func waitForUsersQuotaScan(t *testing.T, token string) { } func waitForFoldersQuotaScanPath(t *testing.T, token string) { - var scans []dataprovider.ActiveVirtualFolderQuotaScan + var scans []common.ActiveVirtualFolderQuotaScan for { req, _ := http.NewRequest(http.MethodGet, quotaScanVFolderPath, nil) setBearerForReq(req, token) diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 26cb127f..5868aa9b 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -1421,7 +1421,7 @@ func TestQuotaScanInvalidFs(t *testing.T) { Provider: sdk.S3FilesystemProvider, }, } - dataprovider.QuotaScans.AddUserQuotaScan(user.Username) + common.QuotaScans.AddUserQuotaScan(user.Username) err := doUserQuotaScan(user) assert.Error(t, err) } diff --git a/internal/httpdtest/httpdtest.go b/internal/httpdtest/httpdtest.go index 40ee67a4..b1ca67ab 100644 --- a/internal/httpdtest/httpdtest.go +++ b/internal/httpdtest/httpdtest.go @@ -833,8 +833,8 @@ func GetEventRules(limit, offset int64, expectedStatusCode int) ([]dataprovider. } // GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode. -func GetQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveQuotaScan, []byte, error) { - var quotaScans []dataprovider.ActiveQuotaScan +func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) { + var quotaScans []common.ActiveQuotaScan var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken()) if err != nil { @@ -1077,8 +1077,8 @@ func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVi } // GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode. -func GetFoldersQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveVirtualFolderQuotaScan, []byte, error) { - var quotaScans []dataprovider.ActiveVirtualFolderQuotaScan +func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) { + var quotaScans []common.ActiveVirtualFolderQuotaScan var body []byte resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken()) if err != nil { diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index ac44c2a8..ae973c44 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -155,7 +155,7 @@ func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir st } func TestRemoveNonexistentQuotaScan(t *testing.T) { - assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan("username")) + assert.False(t, common.QuotaScans.RemoveUserQuotaScan("username")) } func TestGetOSOpenFlags(t *testing.T) { diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index f37880f8..4c93fdc7 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -1213,9 +1213,6 @@ func TestRealPath(t *testing.T) { for _, user := range []dataprovider.User{localUser, sftpUser} { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { - defer conn.Close() - defer client.Close() - p, err := client.RealPath("../..") assert.NoError(t, err) assert.Equal(t, "/", p) @@ -1247,6 +1244,9 @@ func TestRealPath(t *testing.T) { assert.NoError(t, err) _, err = client.RealPath(path.Join(subdir, "temp")) assert.ErrorIs(t, err, os.ErrPermission) + + conn.Close() + client.Close() err = os.Remove(filepath.Join(localUser.GetHomeDir(), subdir, "temp")) assert.NoError(t, err) if user.Username == localUser.Username { @@ -4590,14 +4590,14 @@ func TestQuotaScan(t *testing.T) { } func TestMultipleQuotaScans(t *testing.T) { - res := dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername) + res := common.QuotaScans.AddUserQuotaScan(defaultUsername) assert.True(t, res) - res = dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername) + res = common.QuotaScans.AddUserQuotaScan(defaultUsername) assert.False(t, res, "add quota must fail if another scan is already active") - assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername)) - activeScans := dataprovider.QuotaScans.GetUsersQuotaScans() + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) + activeScans := common.QuotaScans.GetUsersQuotaScans() assert.Equal(t, 0, len(activeScans)) - assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername)) + assert.False(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) } func TestQuotaLimits(t *testing.T) { @@ -6949,15 +6949,15 @@ func TestVirtualFolderQuotaScan(t *testing.T) { func TestVFolderMultipleQuotaScan(t *testing.T) { folderName := "folder_name" - res := dataprovider.QuotaScans.AddVFolderQuotaScan(folderName) + res := common.QuotaScans.AddVFolderQuotaScan(folderName) assert.True(t, res) - res = dataprovider.QuotaScans.AddVFolderQuotaScan(folderName) + res = common.QuotaScans.AddVFolderQuotaScan(folderName) assert.False(t, res) - res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName) + res = common.QuotaScans.RemoveVFolderQuotaScan(folderName) assert.True(t, res) - activeScans := dataprovider.QuotaScans.GetVFoldersQuotaScans() + activeScans := common.QuotaScans.GetVFoldersQuotaScans() assert.Len(t, activeScans, 0) - res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName) + res = common.QuotaScans.RemoveVFolderQuotaScan(folderName) assert.False(t, res) } diff --git a/internal/util/util.go b/internal/util/util.go index 8a273940..d0e956d6 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -607,3 +607,10 @@ func GetTLSVersion(val int) uint16 { func IsEmailValid(email string) bool { return emailRegex.MatchString(email) } + +// PanicOnError calls panic if err is not nil +func PanicOnError(err error) { + if err != nil { + panic(fmt.Errorf("unexpected error: %w", err)) + } +} diff --git a/internal/vfs/cryptfs.go b/internal/vfs/cryptfs.go index 388c9f54..a92c51ee 100644 --- a/internal/vfs/cryptfs.go +++ b/internal/vfs/cryptfs.go @@ -156,7 +156,7 @@ func (fs *CryptFs) Create(name string, flag int) (File, *PipeWriter, func(), err if flag == 0 { f, err = os.Create(name) } else { - f, err = os.OpenFile(name, flag, os.ModePerm) + f, err = os.OpenFile(name, flag, 0666) } if err != nil { return nil, nil, nil, err