refactor: move eventmanager to common package
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
3ca62d76d7
commit
9d2b5dc07d
24 changed files with 2030 additions and 1161 deletions
6
go.mod
6
go.mod
|
@ -52,7 +52,7 @@ require (
|
|||
github.com/rs/xid v1.4.0
|
||||
github.com/rs/zerolog v1.27.0
|
||||
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a
|
||||
github.com/shirou/gopsutil/v3 v3.22.6
|
||||
github.com/shirou/gopsutil/v3 v3.22.7
|
||||
github.com/spf13/afero v1.9.2
|
||||
github.com/spf13/cobra v1.5.0
|
||||
github.com/spf13/viper v1.12.0
|
||||
|
@ -68,7 +68,7 @@ require (
|
|||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
|
||||
golang.org/x/net v0.0.0-20220728211354-c7608f3a8462
|
||||
golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d
|
||||
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9
|
||||
google.golang.org/api v0.90.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
|
@ -155,7 +155,7 @@ require (
|
|||
golang.org/x/tools v0.1.12 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f // indirect
|
||||
google.golang.org/grpc v1.48.0 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
gopkg.in/ini.v1 v1.66.6 // indirect
|
||||
|
|
12
go.sum
12
go.sum
|
@ -712,8 +712,8 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh
|
|||
github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo=
|
||||
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a h1:X9qPZ+GPQ87TnBDNZN6dyX7FkjhwnFh98WgB6Y1T5O8=
|
||||
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a/go.mod h1:RL4HeorXC6XgqtkLYnQUSogLdsdMfbsogIvdBVLuy4w=
|
||||
github.com/shirou/gopsutil/v3 v3.22.6 h1:FnHOFOh+cYAM0C30P+zysPISzlknLC5Z1G4EAElznfQ=
|
||||
github.com/shirou/gopsutil/v3 v3.22.6/go.mod h1:EdIubSnZhbAvBS1yJ7Xi+AShB/hxwLHOMz4MCYz7yMs=
|
||||
github.com/shirou/gopsutil/v3 v3.22.7 h1:flKnuCMfUUrO+oAvwAd6GKZgnPzr098VA/UJ14nhJd4=
|
||||
github.com/shirou/gopsutil/v3 v3.22.7/go.mod h1:s648gW4IywYzUfE/KjXxUsqrqx/T2xO5VqOXxONeRfI=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
|
@ -746,7 +746,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 h1:b2nJXyPCa9HY7giGM+kYcnQ71m14JnGdQabMPmyt++8=
|
||||
|
@ -973,8 +972,9 @@ golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
|
@ -1225,8 +1225,8 @@ google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljW
|
|||
google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
|
||||
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 h1:d3fKQZK+1rWQMg3xLKQbPMirUCo29I/NRdI2WarSzTg=
|
||||
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc=
|
||||
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f h1:XVHpVMvPs4MtH3h6cThzKs2snNexcfd35vQx2T3IuIY=
|
||||
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
|
|
|
@ -102,7 +102,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
|
|||
) error {
|
||||
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
|
||||
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
|
||||
hasRules := dataprovider.EventManager.HasFsRules()
|
||||
hasRules := eventManager.hasFsRules()
|
||||
if !hasHook && !hasNotifiersPlugin && !hasRules {
|
||||
return nil
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
|
|||
}
|
||||
var errRes error
|
||||
if hasRules {
|
||||
errRes = dataprovider.EventManager.HandleFsEvent(dataprovider.EventParams{
|
||||
errRes = eventManager.handleFsEvent(EventParams{
|
||||
Name: notification.Username,
|
||||
Event: notification.Action,
|
||||
Status: notification.Status,
|
||||
|
|
|
@ -138,11 +138,11 @@ var (
|
|||
// Config is the configuration for the supported protocols
|
||||
Config Configuration
|
||||
// Connections is the list of active connections
|
||||
Connections ActiveConnections
|
||||
transfersChecker TransfersChecker
|
||||
periodicTimeoutTicker *time.Ticker
|
||||
periodicTimeoutTickerDone chan bool
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
Connections ActiveConnections
|
||||
// QuotaScans is the list of active quota scans
|
||||
QuotaScans ActiveScans
|
||||
transfersChecker TransfersChecker
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
|
||||
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
|
||||
// the map key is the protocol, for each protocol we can have multiple rate limiters
|
||||
|
@ -157,7 +157,7 @@ func Initialize(c Configuration, isShared int) error {
|
|||
Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true)
|
||||
Config.idleLoginTimeout = 2 * time.Minute
|
||||
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
|
||||
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
|
||||
startPeriodicChecks(periodicTimeoutCheckInterval)
|
||||
Config.defender = nil
|
||||
Config.whitelist = nil
|
||||
rateLimiters = make(map[string][]*rateLimiter)
|
||||
|
@ -308,35 +308,18 @@ func AddDefenderEvent(ip string, event HostEvent) {
|
|||
Config.defender.AddEvent(ip, event)
|
||||
}
|
||||
|
||||
// the ticker cannot be started/stopped from multiple goroutines
|
||||
func startPeriodicTimeoutTicker(duration time.Duration) {
|
||||
stopPeriodicTimeoutTicker()
|
||||
periodicTimeoutTicker = time.NewTicker(duration)
|
||||
periodicTimeoutTickerDone = make(chan bool)
|
||||
go func() {
|
||||
counter := int64(0)
|
||||
func startPeriodicChecks(duration time.Duration) {
|
||||
startEventScheduler()
|
||||
spec := fmt.Sprintf("@every %s", duration)
|
||||
_, err := eventScheduler.AddFunc(spec, Connections.checkTransfers)
|
||||
util.PanicOnError(err)
|
||||
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec)
|
||||
if Config.IdleTimeout > 0 {
|
||||
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
|
||||
for {
|
||||
select {
|
||||
case <-periodicTimeoutTickerDone:
|
||||
return
|
||||
case <-periodicTimeoutTicker.C:
|
||||
counter++
|
||||
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
|
||||
counter = 0
|
||||
Connections.checkIdles()
|
||||
}
|
||||
go Connections.checkTransfers()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopPeriodicTimeoutTicker() {
|
||||
if periodicTimeoutTicker != nil {
|
||||
periodicTimeoutTicker.Stop()
|
||||
periodicTimeoutTickerDone <- true
|
||||
periodicTimeoutTicker = nil
|
||||
spec = fmt.Sprintf("@every %s", duration*ratio)
|
||||
_, err = eventScheduler.AddFunc(spec, Connections.checkIdles)
|
||||
util.PanicOnError(err)
|
||||
logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1162,3 +1145,117 @@ func (c *ConnectionStatus) GetTransfersAsString() string {
|
|||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ActiveQuotaScan defines an active quota scan for a user home dir
|
||||
type ActiveQuotaScan struct {
|
||||
// Username to which the quota scan refers
|
||||
Username string `json:"username"`
|
||||
// quota scan start time as unix timestamp in milliseconds
|
||||
StartTime int64 `json:"start_time"`
|
||||
}
|
||||
|
||||
// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder
|
||||
type ActiveVirtualFolderQuotaScan struct {
|
||||
// folder name to which the quota scan refers
|
||||
Name string `json:"name"`
|
||||
// quota scan start time as unix timestamp in milliseconds
|
||||
StartTime int64 `json:"start_time"`
|
||||
}
|
||||
|
||||
// ActiveScans holds the active quota scans
|
||||
type ActiveScans struct {
|
||||
sync.RWMutex
|
||||
UserScans []ActiveQuotaScan
|
||||
FolderScans []ActiveVirtualFolderQuotaScan
|
||||
}
|
||||
|
||||
// GetUsersQuotaScans returns the active quota scans for users home directories
|
||||
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
scans := make([]ActiveQuotaScan, len(s.UserScans))
|
||||
copy(scans, s.UserScans)
|
||||
return scans
|
||||
}
|
||||
|
||||
// AddUserQuotaScan adds a user to the ones with active quota scans.
|
||||
// Returns false if the user has a quota scan already running
|
||||
func (s *ActiveScans) AddUserQuotaScan(username string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for _, scan := range s.UserScans {
|
||||
if scan.Username == username {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.UserScans = append(s.UserScans, ActiveQuotaScan{
|
||||
Username: username,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveUserQuotaScan removes a user from the ones with active quota scans.
|
||||
// Returns false if the user has no active quota scans
|
||||
func (s *ActiveScans) RemoveUserQuotaScan(username string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for idx, scan := range s.UserScans {
|
||||
if scan.Username == username {
|
||||
lastIdx := len(s.UserScans) - 1
|
||||
s.UserScans[idx] = s.UserScans[lastIdx]
|
||||
s.UserScans = s.UserScans[:lastIdx]
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetVFoldersQuotaScans returns the active quota scans for virtual folders
|
||||
func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans))
|
||||
copy(scans, s.FolderScans)
|
||||
return scans
|
||||
}
|
||||
|
||||
// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans.
|
||||
// Returns false if the folder has a quota scan already running
|
||||
func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for _, scan := range s.FolderScans {
|
||||
if scan.Name == folderName {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{
|
||||
Name: folderName,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans.
|
||||
// Returns false if the folder has no active quota scans
|
||||
func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for idx, scan := range s.FolderScans {
|
||||
if scan.Name == folderName {
|
||||
lastIdx := len(s.FolderScans) - 1
|
||||
s.FolderScans[idx] = s.FolderScans[lastIdx]
|
||||
s.FolderScans = s.FolderScans[:lastIdx]
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
@ -533,19 +532,19 @@ func TestIdleConnections(t *testing.T) {
|
|||
assert.Len(t, Connections.sshConnections, 2)
|
||||
Connections.RUnlock()
|
||||
|
||||
startPeriodicTimeoutTicker(100 * time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
|
||||
startPeriodicChecks(100 * time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond)
|
||||
assert.Eventually(t, func() bool {
|
||||
Connections.RLock()
|
||||
defer Connections.RUnlock()
|
||||
return len(Connections.sshConnections) == 1
|
||||
}, 1*time.Second, 200*time.Millisecond)
|
||||
stopPeriodicTimeoutTicker()
|
||||
stopEventScheduler()
|
||||
assert.Len(t, Connections.GetStats(), 2)
|
||||
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||
sshConn2.lastActivity = c.lastActivity
|
||||
startPeriodicTimeoutTicker(100 * time.Millisecond)
|
||||
startPeriodicChecks(100 * time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
|
||||
assert.Eventually(t, func() bool {
|
||||
Connections.RLock()
|
||||
|
@ -553,7 +552,7 @@ func TestIdleConnections(t *testing.T) {
|
|||
return len(Connections.sshConnections) == 0
|
||||
}, 1*time.Second, 200*time.Millisecond)
|
||||
assert.Equal(t, int32(0), Connections.GetClientConnections())
|
||||
stopPeriodicTimeoutTicker()
|
||||
stopEventScheduler()
|
||||
assert.True(t, customConn1.isClosed)
|
||||
assert.True(t, customConn2.isClosed)
|
||||
|
||||
|
@ -719,6 +718,35 @@ func TestConnectionStatus(t *testing.T) {
|
|||
assert.Len(t, stats, 0)
|
||||
}
|
||||
|
||||
func TestQuotaScans(t *testing.T) {
|
||||
username := "username"
|
||||
assert.True(t, QuotaScans.AddUserQuotaScan(username))
|
||||
assert.False(t, QuotaScans.AddUserQuotaScan(username))
|
||||
usersScans := QuotaScans.GetUsersQuotaScans()
|
||||
if assert.Len(t, usersScans, 1) {
|
||||
assert.Equal(t, usersScans[0].Username, username)
|
||||
assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
|
||||
QuotaScans.UserScans[0].StartTime = 0
|
||||
assert.NotEqual(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
|
||||
}
|
||||
|
||||
assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
|
||||
assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
|
||||
assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
|
||||
assert.Len(t, usersScans, 1)
|
||||
|
||||
folderName := "folder"
|
||||
assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
|
||||
assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
|
||||
if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
|
||||
assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].Name, folderName)
|
||||
}
|
||||
|
||||
assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
|
||||
}
|
||||
|
||||
func TestProxyProtocolVersion(t *testing.T) {
|
||||
c := Configuration{
|
||||
ProxyProtocol: 0,
|
||||
|
@ -1033,110 +1061,6 @@ func TestUserRecentActivity(t *testing.T) {
|
|||
assert.True(t, res)
|
||||
}
|
||||
|
||||
func TestEventRuleMatch(t *testing.T) {
|
||||
conditions := dataprovider.EventConditions{
|
||||
ProviderEvents: []string{"add", "update"},
|
||||
Options: dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "user1",
|
||||
InverseMatch: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res := conditions.ProviderEventMatch(dataprovider.EventParams{
|
||||
Name: "user1",
|
||||
Event: "add",
|
||||
})
|
||||
assert.False(t, res)
|
||||
res = conditions.ProviderEventMatch(dataprovider.EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
})
|
||||
assert.True(t, res)
|
||||
res = conditions.ProviderEventMatch(dataprovider.EventParams{
|
||||
Name: "user2",
|
||||
Event: "delete",
|
||||
})
|
||||
assert.False(t, res)
|
||||
conditions.Options.ProviderObjects = []string{"api_key"}
|
||||
res = conditions.ProviderEventMatch(dataprovider.EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
ObjectType: "share",
|
||||
})
|
||||
assert.False(t, res)
|
||||
res = conditions.ProviderEventMatch(dataprovider.EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
ObjectType: "api_key",
|
||||
})
|
||||
assert.True(t, res)
|
||||
// now test fs events
|
||||
conditions = dataprovider.EventConditions{
|
||||
FsEvents: []string{operationUpload, operationDownload},
|
||||
Options: dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "user*",
|
||||
},
|
||||
{
|
||||
Pattern: "tester*",
|
||||
},
|
||||
},
|
||||
FsPaths: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "*.txt",
|
||||
},
|
||||
},
|
||||
Protocols: []string{ProtocolSFTP},
|
||||
MinFileSize: 10,
|
||||
MaxFileSize: 30,
|
||||
},
|
||||
}
|
||||
params := dataprovider.EventParams{
|
||||
Name: "tester4",
|
||||
Event: operationDelete,
|
||||
VirtualPath: "/path.txt",
|
||||
Protocol: ProtocolSFTP,
|
||||
ObjectName: "path.txt",
|
||||
FileSize: 20,
|
||||
}
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.Event = operationDownload
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.True(t, res)
|
||||
params.Name = "name"
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.Name = "user5"
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.True(t, res)
|
||||
params.VirtualPath = "/sub/f.jpg"
|
||||
params.ObjectName = path.Base(params.VirtualPath)
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.VirtualPath = "/sub/f.txt"
|
||||
params.ObjectName = path.Base(params.VirtualPath)
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.True(t, res)
|
||||
params.Protocol = ProtocolHTTP
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.Protocol = ProtocolSFTP
|
||||
params.FileSize = 5
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.FileSize = 50
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.False(t, res)
|
||||
params.FileSize = 25
|
||||
res = conditions.FsEventMatch(params)
|
||||
assert.True(t, res)
|
||||
}
|
||||
|
||||
func BenchmarkBcryptHashing(b *testing.B) {
|
||||
bcryptPassword := "bcryptpassword"
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
776
internal/common/eventmanager.go
Normal file
776
internal/common/eventmanager.go
Normal file
|
@ -0,0 +1,776 @@
|
|||
// Copyright (C) 2019-2022 Nicola Murino
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published
|
||||
// by the Free Software Foundation, version 3.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/plugin"
|
||||
"github.com/drakkan/sftpgo/v2/internal/smtp"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
"github.com/drakkan/sftpgo/v2/internal/vfs"
|
||||
)
|
||||
|
||||
var (
|
||||
// eventManager handle the supported event rules actions
|
||||
eventManager eventRulesContainer
|
||||
)
|
||||
|
||||
func init() {
|
||||
eventManager = eventRulesContainer{
|
||||
schedulesMapping: make(map[string][]cron.EntryID),
|
||||
}
|
||||
dataprovider.SetEventRulesCallbacks(eventManager.loadRules, eventManager.RemoveRule,
|
||||
func(operation, executor, ip, objectType, objectName string, object plugin.Renderer) {
|
||||
eventManager.handleProviderEvent(EventParams{
|
||||
Name: executor,
|
||||
ObjectName: objectName,
|
||||
Event: operation,
|
||||
Status: 1,
|
||||
ObjectType: objectType,
|
||||
IP: ip,
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
Object: object,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// eventRulesContainer stores event rules by trigger
|
||||
type eventRulesContainer struct {
|
||||
sync.RWMutex
|
||||
FsEvents []dataprovider.EventRule
|
||||
ProviderEvents []dataprovider.EventRule
|
||||
Schedules []dataprovider.EventRule
|
||||
schedulesMapping map[string][]cron.EntryID
|
||||
lastLoad int64
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) getLastLoadTime() int64 {
|
||||
return atomic.LoadInt64(&r.lastLoad)
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) setLastLoadTime(modTime int64) {
|
||||
atomic.StoreInt64(&r.lastLoad, modTime)
|
||||
}
|
||||
|
||||
// RemoveRule deletes the rule with the specified name
|
||||
func (r *eventRulesContainer) RemoveRule(name string) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
r.removeRuleInternal(name)
|
||||
eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d",
|
||||
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) removeRuleInternal(name string) {
|
||||
for idx := range r.FsEvents {
|
||||
if r.FsEvents[idx].Name == name {
|
||||
lastIdx := len(r.FsEvents) - 1
|
||||
r.FsEvents[idx] = r.FsEvents[lastIdx]
|
||||
r.FsEvents = r.FsEvents[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
for idx := range r.ProviderEvents {
|
||||
if r.ProviderEvents[idx].Name == name {
|
||||
lastIdx := len(r.ProviderEvents) - 1
|
||||
r.ProviderEvents[idx] = r.ProviderEvents[lastIdx]
|
||||
r.ProviderEvents = r.ProviderEvents[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
for idx := range r.Schedules {
|
||||
if r.Schedules[idx].Name == name {
|
||||
if schedules, ok := r.schedulesMapping[name]; ok {
|
||||
for _, entryID := range schedules {
|
||||
eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name)
|
||||
eventScheduler.Remove(entryID)
|
||||
}
|
||||
delete(r.schedulesMapping, name)
|
||||
}
|
||||
|
||||
lastIdx := len(r.Schedules) - 1
|
||||
r.Schedules[idx] = r.Schedules[lastIdx]
|
||||
r.Schedules = r.Schedules[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule) {
|
||||
r.removeRuleInternal(rule.Name)
|
||||
if rule.DeletedAt > 0 {
|
||||
deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt)
|
||||
if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
|
||||
eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt)
|
||||
go dataprovider.RemoveEventRule(rule) //nolint:errcheck
|
||||
}
|
||||
return
|
||||
}
|
||||
switch rule.Trigger {
|
||||
case dataprovider.EventTriggerFsEvent:
|
||||
r.FsEvents = append(r.FsEvents, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name)
|
||||
case dataprovider.EventTriggerProviderEvent:
|
||||
r.ProviderEvents = append(r.ProviderEvents, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name)
|
||||
case dataprovider.EventTriggerSchedule:
|
||||
for _, schedule := range rule.Conditions.Schedules {
|
||||
cronSpec := schedule.GetCronSpec()
|
||||
job := &eventCronJob{
|
||||
ruleName: dataprovider.ConvertName(rule.Name),
|
||||
}
|
||||
entryID, err := eventScheduler.AddJob(cronSpec, job)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to add scheduled rule %q, cron string %q: %v", rule.Name, cronSpec, err)
|
||||
return
|
||||
}
|
||||
r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID)
|
||||
eventManagerLog(logger.LevelDebug, "schedule for rule %q added, id: %d, cron string %q, active scheduling rules: %d",
|
||||
rule.Name, entryID, cronSpec, len(r.schedulesMapping))
|
||||
}
|
||||
r.Schedules = append(r.Schedules, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name)
|
||||
default:
|
||||
eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) loadRules() {
|
||||
eventManagerLog(logger.LevelDebug, "loading updated rules")
|
||||
modTime := util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
rules, err := dataprovider.GetRecentlyUpdatedRules(r.getLastLoadTime())
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to load event rules: %v", err)
|
||||
return
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules))
|
||||
|
||||
if len(rules) > 0 {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
for _, rule := range rules {
|
||||
r.addUpdateRuleInternal(rule)
|
||||
}
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d",
|
||||
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
|
||||
|
||||
r.setLastLoadTime(modTime)
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) checkProviderEventMatch(conditions dataprovider.EventConditions, params EventParams) bool {
|
||||
if !util.Contains(conditions.ProviderEvents, params.Event) {
|
||||
return false
|
||||
}
|
||||
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
|
||||
return false
|
||||
}
|
||||
if len(conditions.Options.ProviderObjects) > 0 && !util.Contains(conditions.Options.ProviderObjects, params.ObjectType) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) checkFsEventMatch(conditions dataprovider.EventConditions, params EventParams) bool {
|
||||
if !util.Contains(conditions.FsEvents, params.Event) {
|
||||
return false
|
||||
}
|
||||
if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
|
||||
return false
|
||||
}
|
||||
if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) {
|
||||
if !checkEventConditionPatterns(params.ObjectName, conditions.Options.FsPaths) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(conditions.Options.Protocols) > 0 && !util.Contains(conditions.Options.Protocols, params.Protocol) {
|
||||
return false
|
||||
}
|
||||
if params.Event == operationUpload || params.Event == operationDownload {
|
||||
if conditions.Options.MinFileSize > 0 {
|
||||
if params.FileSize < conditions.Options.MinFileSize {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if conditions.Options.MaxFileSize > 0 {
|
||||
if params.FileSize > conditions.Options.MaxFileSize {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// hasFsRules returns true if there are any rules for filesystem event triggers
|
||||
func (r *eventRulesContainer) hasFsRules() bool {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
return len(r.FsEvents) > 0
|
||||
}
|
||||
|
||||
// handleFsEvent executes the rules actions defined for the specified event
|
||||
func (r *eventRulesContainer) handleFsEvent(params EventParams) error {
|
||||
r.RLock()
|
||||
|
||||
var rulesWithSyncActions, rulesAsync []dataprovider.EventRule
|
||||
for _, rule := range r.FsEvents {
|
||||
if r.checkFsEventMatch(rule.Conditions, params) {
|
||||
hasSyncActions := false
|
||||
for _, action := range rule.Actions {
|
||||
if action.Options.ExecuteSync {
|
||||
hasSyncActions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasSyncActions {
|
||||
rulesWithSyncActions = append(rulesWithSyncActions, rule)
|
||||
} else {
|
||||
rulesAsync = append(rulesAsync, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.RUnlock()
|
||||
|
||||
if len(rulesAsync) > 0 {
|
||||
go executeAsyncRulesActions(rulesAsync, params)
|
||||
}
|
||||
|
||||
if len(rulesWithSyncActions) > 0 {
|
||||
return executeSyncRulesActions(rulesWithSyncActions, params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *eventRulesContainer) handleProviderEvent(params EventParams) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
var rules []dataprovider.EventRule
|
||||
for _, rule := range r.ProviderEvents {
|
||||
if r.checkProviderEventMatch(rule.Conditions, params) {
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rules) > 0 {
|
||||
go executeAsyncRulesActions(rules, params)
|
||||
}
|
||||
}
|
||||
|
||||
// EventParams defines the supported event parameters
|
||||
type EventParams struct {
|
||||
Name string
|
||||
Event string
|
||||
Status int
|
||||
VirtualPath string
|
||||
FsPath string
|
||||
VirtualTargetPath string
|
||||
FsTargetPath string
|
||||
ObjectName string
|
||||
ObjectType string
|
||||
FileSize int64
|
||||
Protocol string
|
||||
IP string
|
||||
Timestamp int64
|
||||
Object plugin.Renderer
|
||||
}
|
||||
|
||||
func (p *EventParams) getStringReplacements(addObjectData bool) []string {
|
||||
replacements := []string{
|
||||
"{{Name}}", p.Name,
|
||||
"{{Event}}", p.Event,
|
||||
"{{Status}}", fmt.Sprintf("%d", p.Status),
|
||||
"{{VirtualPath}}", p.VirtualPath,
|
||||
"{{FsPath}}", p.FsPath,
|
||||
"{{VirtualTargetPath}}", p.VirtualTargetPath,
|
||||
"{{FsTargetPath}}", p.FsTargetPath,
|
||||
"{{ObjectName}}", p.ObjectName,
|
||||
"{{ObjectType}}", p.ObjectType,
|
||||
"{{FileSize}}", fmt.Sprintf("%d", p.FileSize),
|
||||
"{{Protocol}}", p.Protocol,
|
||||
"{{IP}}", p.IP,
|
||||
"{{Timestamp}}", fmt.Sprintf("%d", p.Timestamp),
|
||||
}
|
||||
if addObjectData {
|
||||
data, err := p.Object.RenderAsJSON(p.Event != operationDelete)
|
||||
if err == nil {
|
||||
replacements = append(replacements, "{{ObjectData}}", string(data))
|
||||
}
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
func replaceWithReplacer(input string, replacer *strings.Replacer) string {
|
||||
if !strings.Contains(input, "{{") {
|
||||
return input
|
||||
}
|
||||
return replacer.Replace(input)
|
||||
}
|
||||
|
||||
func checkEventConditionPattern(p dataprovider.ConditionPattern, name string) bool {
|
||||
matched, err := path.Match(p.Pattern, name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err)
|
||||
return false
|
||||
}
|
||||
if p.InverseMatch {
|
||||
return !matched
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
// checkConditionPatterns returns false if patterns are defined and no match is found
|
||||
func checkEventConditionPatterns(name string, patterns []dataprovider.ConditionPattern) bool {
|
||||
if len(patterns) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range patterns {
|
||||
if checkEventConditionPattern(p, name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getHTTPRuleActionEndpoint(c dataprovider.EventActionHTTPConfig, replacer *strings.Replacer) (string, error) {
|
||||
if len(c.QueryParameters) > 0 {
|
||||
u, err := url.Parse(c.Endpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid endpoint: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
|
||||
for _, keyVal := range c.QueryParameters {
|
||||
q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
return c.Endpoint, nil
|
||||
}
|
||||
|
||||
func executeHTTPRuleAction(c dataprovider.EventActionHTTPConfig, params EventParams) error {
|
||||
if !c.Password.IsEmpty() {
|
||||
if err := c.Password.TryDecrypt(); err != nil {
|
||||
return fmt.Errorf("unable to decrypt password: %w", err)
|
||||
}
|
||||
}
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
if !addObjectData {
|
||||
if strings.Contains(c.Body, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
endpoint, err := getHTTPRuleActionEndpoint(c, replacer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if c.Body != "" && c.Method != http.MethodGet {
|
||||
body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer))
|
||||
}
|
||||
req, err := http.NewRequest(c.Method, endpoint, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Username != "" {
|
||||
req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData())
|
||||
}
|
||||
for _, keyVal := range c.Headers {
|
||||
req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
|
||||
}
|
||||
client := c.GetHTTPClient()
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
startTime := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v",
|
||||
endpoint, time.Since(startTime), err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
eventManagerLog(logger.LevelDebug, "http notification sent, endopoint: %s, elapsed: %s, status code: %d",
|
||||
endpoint, time.Since(startTime), resp.StatusCode)
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeCommandRuleAction(c dataprovider.EventActionCommandConfig, params EventParams) error {
|
||||
envVars := make([]string, 0, len(c.EnvVars))
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
for _, k := range c.EnvVars {
|
||||
if strings.Contains(k.Value, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
for _, keyVal := range c.EnvVars {
|
||||
envVars = append(envVars, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.Cmd)
|
||||
cmd.Env = append(cmd.Env, os.Environ()...)
|
||||
cmd.Env = append(cmd.Env, envVars...)
|
||||
|
||||
startTime := time.Now()
|
||||
err := cmd.Run()
|
||||
|
||||
eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v",
|
||||
c.Cmd, time.Since(startTime), err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params EventParams) error {
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
if strings.Contains(c.Body, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
}
|
||||
}
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
body := replaceWithReplacer(c.Body, replacer)
|
||||
subject := replaceWithReplacer(c.Subject, replacer)
|
||||
startTime := time.Now()
|
||||
err := smtp.SendEmail(c.Recipients, subject, body, smtp.EmailContentTypeTextPlain)
|
||||
eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v",
|
||||
time.Since(startTime), err)
|
||||
return err
|
||||
}
|
||||
|
||||
func executeUsersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
|
||||
users, err := dataprovider.DumpUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get users: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, user := range users {
|
||||
if !checkEventConditionPatterns(user.Username, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match",
|
||||
user.Username)
|
||||
continue
|
||||
}
|
||||
if !QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
numFiles, size, err := user.ScanQuota()
|
||||
QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("quota reset failed for users: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeFoldersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
|
||||
folders, err := dataprovider.DumpFolders()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get folders: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, folder := range folders {
|
||||
if !checkEventConditionPatterns(folder.Name, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match",
|
||||
folder.Name)
|
||||
continue
|
||||
}
|
||||
if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
f := vfs.VirtualFolder{
|
||||
BaseVirtualFolder: folder,
|
||||
VirtualPath: "/",
|
||||
}
|
||||
numFiles, size, err := f.ScanQuota()
|
||||
QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, err)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("quota reset failed for folders: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeTransferQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
|
||||
users, err := dataprovider.DumpUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get users: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, user := range users {
|
||||
if !checkEventConditionPatterns(user.Username, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match",
|
||||
user.Username)
|
||||
continue
|
||||
}
|
||||
err = dataprovider.UpdateUserTransferQuota(&user, 0, 0, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("transfer quota reset failed for users: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func executeRuleAction(action dataprovider.BaseEventAction, params EventParams, conditions dataprovider.ConditionOptions) error {
|
||||
switch action.Type {
|
||||
case dataprovider.ActionTypeHTTP:
|
||||
return executeHTTPRuleAction(action.Options.HTTPConfig, params)
|
||||
case dataprovider.ActionTypeCommand:
|
||||
return executeCommandRuleAction(action.Options.CmdConfig, params)
|
||||
case dataprovider.ActionTypeEmail:
|
||||
return executeEmailRuleAction(action.Options.EmailConfig, params)
|
||||
case dataprovider.ActionTypeBackup:
|
||||
return dataprovider.ExecuteBackup()
|
||||
case dataprovider.ActionTypeUserQuotaReset:
|
||||
return executeUsersQuotaResetRuleAction(conditions)
|
||||
case dataprovider.ActionTypeFolderQuotaReset:
|
||||
return executeFoldersQuotaResetRuleAction(conditions)
|
||||
case dataprovider.ActionTypeTransferQuotaReset:
|
||||
return executeTransferQuotaResetRuleAction(conditions)
|
||||
default:
|
||||
return fmt.Errorf("unsupported action type: %d", action.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func executeSyncRulesActions(rules []dataprovider.EventRule, params EventParams) error {
|
||||
var errRes error
|
||||
|
||||
for _, rule := range rules {
|
||||
var failedActions []string
|
||||
for _, action := range rule.Actions {
|
||||
if !action.Options.IsFailureAction && action.Options.ExecuteSync {
|
||||
startTime := time.Now()
|
||||
if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
failedActions = append(failedActions, action.Name)
|
||||
// we return the last error, it is ok for now
|
||||
errRes = err
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
// execute async actions if any, including failure actions
|
||||
go executeRuleAsyncActions(rule, params, failedActions)
|
||||
}
|
||||
|
||||
return errRes
|
||||
}
|
||||
|
||||
func executeAsyncRulesActions(rules []dataprovider.EventRule, params EventParams) {
|
||||
for _, rule := range rules {
|
||||
executeRuleAsyncActions(rule, params, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func executeRuleAsyncActions(rule dataprovider.EventRule, params EventParams, failedActions []string) {
|
||||
for _, action := range rule.Actions {
|
||||
if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
|
||||
startTime := time.Now()
|
||||
if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
failedActions = append(failedActions, action.Name)
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(failedActions) > 0 {
|
||||
// execute failure actions
|
||||
for _, action := range rule.Actions {
|
||||
if action.Options.IsFailureAction {
|
||||
startTime := time.Now()
|
||||
if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type eventCronJob struct {
|
||||
ruleName string
|
||||
}
|
||||
|
||||
func (j *eventCronJob) getTask(rule dataprovider.EventRule) (dataprovider.Task, error) {
|
||||
if rule.GuardFromConcurrentExecution() {
|
||||
task, err := dataprovider.GetTaskByName(rule.Name)
|
||||
if _, ok := err.(*util.RecordNotFoundError); ok {
|
||||
eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
|
||||
task = dataprovider.Task{
|
||||
Name: rule.Name,
|
||||
UpdateAt: 0,
|
||||
Version: 0,
|
||||
}
|
||||
err = dataprovider.AddTask(rule.Name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
|
||||
return task, err
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err)
|
||||
}
|
||||
return task, err
|
||||
}
|
||||
|
||||
return dataprovider.Task{}, nil
|
||||
}
|
||||
|
||||
func (j *eventCronJob) Run() {
|
||||
eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName)
|
||||
rule, err := dataprovider.EventRuleExists(j.ruleName)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName)
|
||||
return
|
||||
}
|
||||
task, err := j.getTask(rule)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if task.Name != "" {
|
||||
updateInterval := 5 * time.Minute
|
||||
updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt)
|
||||
if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) {
|
||||
eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt)
|
||||
return
|
||||
}
|
||||
err = dataprovider.UpdateTask(rule.Name, task.Version)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v",
|
||||
rule.Name, err)
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(updateInterval)
|
||||
done := make(chan bool)
|
||||
|
||||
go func(taskName string) {
|
||||
eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName)
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := dataprovider.UpdateTaskTimestamp(taskName)
|
||||
eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err)
|
||||
}
|
||||
}
|
||||
}(task.Name)
|
||||
|
||||
executeRuleAsyncActions(rule, EventParams{}, nil)
|
||||
|
||||
done <- true
|
||||
ticker.Stop()
|
||||
} else {
|
||||
executeRuleAsyncActions(rule, EventParams{}, nil)
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName)
|
||||
}
|
||||
|
||||
func eventManagerLog(level logger.LogLevel, format string, v ...any) {
|
||||
logger.Log(level, "eventmanager", "", format, v...)
|
||||
}
|
674
internal/common/eventmanager_test.go
Normal file
674
internal/common/eventmanager_test.go
Normal file
|
@ -0,0 +1,674 @@
|
|||
// Copyright (C) 2019-2022 Nicola Murino
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published
|
||||
// by the Free Software Foundation, version 3.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sftpgo/sdk"
|
||||
sdkkms "github.com/sftpgo/sdk/kms"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/internal/kms"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
"github.com/drakkan/sftpgo/v2/internal/vfs"
|
||||
)
|
||||
|
||||
func TestEventRuleMatch(t *testing.T) {
|
||||
conditions := dataprovider.EventConditions{
|
||||
ProviderEvents: []string{"add", "update"},
|
||||
Options: dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "user1",
|
||||
InverseMatch: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
res := eventManager.checkProviderEventMatch(conditions, EventParams{
|
||||
Name: "user1",
|
||||
Event: "add",
|
||||
})
|
||||
assert.False(t, res)
|
||||
res = eventManager.checkProviderEventMatch(conditions, EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
})
|
||||
assert.True(t, res)
|
||||
res = eventManager.checkProviderEventMatch(conditions, EventParams{
|
||||
Name: "user2",
|
||||
Event: "delete",
|
||||
})
|
||||
assert.False(t, res)
|
||||
conditions.Options.ProviderObjects = []string{"api_key"}
|
||||
res = eventManager.checkProviderEventMatch(conditions, EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
ObjectType: "share",
|
||||
})
|
||||
assert.False(t, res)
|
||||
res = eventManager.checkProviderEventMatch(conditions, EventParams{
|
||||
Name: "user2",
|
||||
Event: "update",
|
||||
ObjectType: "api_key",
|
||||
})
|
||||
assert.True(t, res)
|
||||
// now test fs events
|
||||
conditions = dataprovider.EventConditions{
|
||||
FsEvents: []string{operationUpload, operationDownload},
|
||||
Options: dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "user*",
|
||||
},
|
||||
{
|
||||
Pattern: "tester*",
|
||||
},
|
||||
},
|
||||
FsPaths: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "*.txt",
|
||||
},
|
||||
},
|
||||
Protocols: []string{ProtocolSFTP},
|
||||
MinFileSize: 10,
|
||||
MaxFileSize: 30,
|
||||
},
|
||||
}
|
||||
params := EventParams{
|
||||
Name: "tester4",
|
||||
Event: operationDelete,
|
||||
VirtualPath: "/path.txt",
|
||||
Protocol: ProtocolSFTP,
|
||||
ObjectName: "path.txt",
|
||||
FileSize: 20,
|
||||
}
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.Event = operationDownload
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.True(t, res)
|
||||
params.Name = "name"
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.Name = "user5"
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.True(t, res)
|
||||
params.VirtualPath = "/sub/f.jpg"
|
||||
params.ObjectName = path.Base(params.VirtualPath)
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.VirtualPath = "/sub/f.txt"
|
||||
params.ObjectName = path.Base(params.VirtualPath)
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.True(t, res)
|
||||
params.Protocol = ProtocolHTTP
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.Protocol = ProtocolSFTP
|
||||
params.FileSize = 5
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.FileSize = 50
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
params.FileSize = 25
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.True(t, res)
|
||||
// bad pattern
|
||||
conditions.Options.Names = []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: "[-]",
|
||||
},
|
||||
}
|
||||
res = eventManager.checkFsEventMatch(conditions, params)
|
||||
assert.False(t, res)
|
||||
}
|
||||
|
||||
func TestEventManager(t *testing.T) {
|
||||
startEventScheduler()
|
||||
action := &dataprovider.BaseEventAction{
|
||||
Name: "test_action",
|
||||
Type: dataprovider.ActionTypeHTTP,
|
||||
Options: dataprovider.BaseEventActionOptions{
|
||||
HTTPConfig: dataprovider.EventActionHTTPConfig{
|
||||
Endpoint: "http://localhost",
|
||||
Timeout: 20,
|
||||
Method: http.MethodGet,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := dataprovider.AddEventAction(action, "", "")
|
||||
assert.NoError(t, err)
|
||||
rule := &dataprovider.EventRule{
|
||||
Name: "rule",
|
||||
Trigger: dataprovider.EventTriggerFsEvent,
|
||||
Conditions: dataprovider.EventConditions{
|
||||
FsEvents: []string{operationUpload},
|
||||
},
|
||||
Actions: []dataprovider.EventAction{
|
||||
{
|
||||
BaseEventAction: dataprovider.BaseEventAction{
|
||||
Name: action.Name,
|
||||
},
|
||||
Order: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = dataprovider.AddEventRule(rule, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 1)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
assert.Len(t, eventManager.schedulesMapping, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
rule.Trigger = dataprovider.EventTriggerProviderEvent
|
||||
rule.Conditions = dataprovider.EventConditions{
|
||||
ProviderEvents: []string{"add"},
|
||||
}
|
||||
err = dataprovider.UpdateEventRule(rule, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 1)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
assert.Len(t, eventManager.schedulesMapping, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
rule.Trigger = dataprovider.EventTriggerSchedule
|
||||
rule.Conditions = dataprovider.EventConditions{
|
||||
Schedules: []dataprovider.Schedule{
|
||||
{
|
||||
Hours: "0",
|
||||
DayOfWeek: "*",
|
||||
DayOfMonth: "*",
|
||||
Month: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
rule.DeletedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour))
|
||||
eventManager.addUpdateRuleInternal(*rule)
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
assert.Len(t, eventManager.schedulesMapping, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
_, err = dataprovider.EventRuleExists(rule.Name)
|
||||
_, ok := err.(*util.RecordNotFoundError)
|
||||
return ok
|
||||
}, 2*time.Second, 100*time.Millisecond)
|
||||
|
||||
rule.DeletedAt = 0
|
||||
err = dataprovider.AddEventRule(rule, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 1)
|
||||
assert.Len(t, eventManager.schedulesMapping, 1)
|
||||
eventManager.RUnlock()
|
||||
|
||||
err = dataprovider.DeleteEventRule(rule.Name, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
assert.Len(t, eventManager.schedulesMapping, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
err = dataprovider.DeleteEventAction(action.Name, "", "")
|
||||
assert.NoError(t, err)
|
||||
stopEventScheduler()
|
||||
}
|
||||
|
||||
func TestEventManagerErrors(t *testing.T) {
|
||||
startEventScheduler()
|
||||
providerConf := dataprovider.GetProviderConfig()
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = executeUsersQuotaResetRuleAction(dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
err = executeFoldersQuotaResetRuleAction(dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
err = executeTransferQuotaResetRuleAction(dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
|
||||
eventManager.loadRules()
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
// rule with invalid trigger
|
||||
eventManager.addUpdateRuleInternal(dataprovider.EventRule{
|
||||
Name: "test rule",
|
||||
Trigger: -1,
|
||||
})
|
||||
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
eventManager.RUnlock()
|
||||
// rule with invalid cronspec
|
||||
eventManager.addUpdateRuleInternal(dataprovider.EventRule{
|
||||
Name: "test rule",
|
||||
Trigger: dataprovider.EventTriggerSchedule,
|
||||
Conditions: dataprovider.EventConditions{
|
||||
Schedules: []dataprovider.Schedule{
|
||||
{
|
||||
Hours: "1000",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
eventManager.RLock()
|
||||
assert.Len(t, eventManager.FsEvents, 0)
|
||||
assert.Len(t, eventManager.ProviderEvents, 0)
|
||||
assert.Len(t, eventManager.Schedules, 0)
|
||||
eventManager.RUnlock()
|
||||
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
stopEventScheduler()
|
||||
}
|
||||
|
||||
func TestEventRuleActions(t *testing.T) {
|
||||
actionName := "test rule action"
|
||||
action := dataprovider.BaseEventAction{
|
||||
Name: actionName,
|
||||
Type: dataprovider.ActionTypeBackup,
|
||||
}
|
||||
err := executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
|
||||
assert.NoError(t, err)
|
||||
action.Type = -1
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
|
||||
action = dataprovider.BaseEventAction{
|
||||
Name: actionName,
|
||||
Type: dataprovider.ActionTypeHTTP,
|
||||
Options: dataprovider.BaseEventActionOptions{
|
||||
HTTPConfig: dataprovider.EventActionHTTPConfig{
|
||||
Endpoint: "http://foo\x7f.com/", // invalid URL
|
||||
SkipTLSVerify: true,
|
||||
Body: "{{ObjectData}}",
|
||||
Method: http.MethodPost,
|
||||
QueryParameters: []dataprovider.KeyValue{
|
||||
{
|
||||
Key: "param",
|
||||
Value: "value",
|
||||
},
|
||||
},
|
||||
Timeout: 5,
|
||||
Headers: []dataprovider.KeyValue{
|
||||
{
|
||||
Key: "Content-Type",
|
||||
Value: "application/json",
|
||||
},
|
||||
},
|
||||
Username: "httpuser",
|
||||
},
|
||||
},
|
||||
}
|
||||
action.Options.SetEmptySecretsIfNil()
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "invalid endpoint")
|
||||
}
|
||||
action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr)
|
||||
params := EventParams{
|
||||
Name: "a",
|
||||
Object: &dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: "test user",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
|
||||
assert.NoError(t, err)
|
||||
action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v/404", httpAddr)
|
||||
err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, err.Error(), "unexpected status code: 404")
|
||||
}
|
||||
action.Options.HTTPConfig.Endpoint = "http://invalid:1234"
|
||||
err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
action.Options.HTTPConfig.QueryParameters = nil
|
||||
action.Options.HTTPConfig.Endpoint = "http://bar\x7f.com/"
|
||||
err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
|
||||
assert.Error(t, err)
|
||||
action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", "data")
|
||||
err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to decrypt password")
|
||||
}
|
||||
// test disk and transfer quota reset
|
||||
username1 := "user1"
|
||||
username2 := "user2"
|
||||
user1 := dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: username1,
|
||||
HomeDir: filepath.Join(os.TempDir(), username1),
|
||||
Status: 1,
|
||||
Permissions: map[string][]string{
|
||||
"/": {dataprovider.PermAny},
|
||||
},
|
||||
},
|
||||
}
|
||||
user2 := dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: username2,
|
||||
HomeDir: filepath.Join(os.TempDir(), username2),
|
||||
Status: 1,
|
||||
Permissions: map[string][]string{
|
||||
"/": {dataprovider.PermAny},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = dataprovider.AddUser(&user1, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.AddUser(&user2, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
action = dataprovider.BaseEventAction{
|
||||
Type: dataprovider.ActionTypeUserQuotaReset,
|
||||
}
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err) // no home dir
|
||||
// create the home dir
|
||||
err = os.MkdirAll(user1.GetHomeDir(), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(user1.GetHomeDir(), "file.txt"), []byte("user"), 0666)
|
||||
assert.NoError(t, err)
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
userGet, err := dataprovider.UserExists(username1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, userGet.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(4), userGet.UsedQuotaSize)
|
||||
// simulate another quota scan in progress
|
||||
assert.True(t, QuotaScans.AddUserQuotaScan(username1))
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.True(t, QuotaScans.RemoveUserQuotaScan(username1))
|
||||
|
||||
err = os.RemoveAll(user1.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.UpdateUserTransferQuota(&user1, 100, 100, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
action.Type = dataprovider.ActionTypeTransferQuotaReset
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
userGet, err = dataprovider.UserExists(username1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer)
|
||||
assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer)
|
||||
|
||||
err = dataprovider.DeleteUser(username1, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteUser(username2, "", "")
|
||||
assert.NoError(t, err)
|
||||
// test folder quota reset
|
||||
foldername1 := "f1"
|
||||
foldername2 := "f2"
|
||||
folder1 := vfs.BaseVirtualFolder{
|
||||
Name: foldername1,
|
||||
MappedPath: filepath.Join(os.TempDir(), foldername1),
|
||||
}
|
||||
folder2 := vfs.BaseVirtualFolder{
|
||||
Name: foldername2,
|
||||
MappedPath: filepath.Join(os.TempDir(), foldername2),
|
||||
}
|
||||
err = dataprovider.AddFolder(&folder1, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.AddFolder(&folder2, "", "")
|
||||
assert.NoError(t, err)
|
||||
action = dataprovider.BaseEventAction{
|
||||
Type: dataprovider.ActionTypeFolderQuotaReset,
|
||||
}
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: foldername1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err) // no home dir
|
||||
err = os.MkdirAll(folder1.MappedPath, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(folder1.MappedPath, "file.txt"), []byte("folder"), 0666)
|
||||
assert.NoError(t, err)
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: foldername1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
folderGet, err := dataprovider.GetFolderByName(foldername1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
|
||||
assert.Equal(t, int64(6), folderGet.UsedQuotaSize)
|
||||
// simulate another quota scan in progress
|
||||
assert.True(t, QuotaScans.AddVFolderQuotaScan(foldername1))
|
||||
err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: foldername1,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.True(t, QuotaScans.RemoveVFolderQuotaScan(foldername1))
|
||||
|
||||
err = os.RemoveAll(folder1.MappedPath)
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteFolder(foldername1, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteFolder(foldername2, "", "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestQuotaActionsWithQuotaTrackDisabled(t *testing.T) {
|
||||
oldProviderConf := dataprovider.GetProviderConfig()
|
||||
providerConf := dataprovider.GetProviderConfig()
|
||||
providerConf.TrackQuota = 0
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
username := "u1"
|
||||
user := dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: username,
|
||||
HomeDir: filepath.Join(os.TempDir(), username),
|
||||
Status: 1,
|
||||
Permissions: map[string][]string{
|
||||
"/": {dataprovider.PermAny},
|
||||
},
|
||||
},
|
||||
FsConfig: vfs.Filesystem{
|
||||
Provider: sdk.LocalFilesystemProvider,
|
||||
},
|
||||
}
|
||||
err = dataprovider.AddUser(&user, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeUserQuotaReset},
|
||||
EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeTransferQuotaReset},
|
||||
EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: username,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteUser(username, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
foldername := "f1"
|
||||
folder := vfs.BaseVirtualFolder{
|
||||
Name: foldername,
|
||||
MappedPath: filepath.Join(os.TempDir(), foldername),
|
||||
}
|
||||
err = dataprovider.AddFolder(&folder, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = os.MkdirAll(folder.MappedPath, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeFolderQuotaReset},
|
||||
EventParams{}, dataprovider.ConditionOptions{
|
||||
Names: []dataprovider.ConditionPattern{
|
||||
{
|
||||
Pattern: foldername,
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
err = os.RemoveAll(folder.MappedPath)
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteFolder(foldername, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.Initialize(oldProviderConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestScheduledActions(t *testing.T) {
|
||||
startEventScheduler()
|
||||
backupsPath := filepath.Join(os.TempDir(), "backups")
|
||||
err := os.RemoveAll(backupsPath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
action := &dataprovider.BaseEventAction{
|
||||
Name: "action",
|
||||
Type: dataprovider.ActionTypeBackup,
|
||||
}
|
||||
err = dataprovider.AddEventAction(action, "", "")
|
||||
assert.NoError(t, err)
|
||||
rule := &dataprovider.EventRule{
|
||||
Name: "rule",
|
||||
Trigger: dataprovider.EventTriggerSchedule,
|
||||
Conditions: dataprovider.EventConditions{
|
||||
Schedules: []dataprovider.Schedule{
|
||||
{
|
||||
Hours: "11",
|
||||
DayOfWeek: "*",
|
||||
DayOfMonth: "*",
|
||||
Month: "*",
|
||||
},
|
||||
},
|
||||
},
|
||||
Actions: []dataprovider.EventAction{
|
||||
{
|
||||
BaseEventAction: dataprovider.BaseEventAction{
|
||||
Name: action.Name,
|
||||
},
|
||||
Order: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
job := eventCronJob{
|
||||
ruleName: rule.Name,
|
||||
}
|
||||
job.Run() // rule not found
|
||||
assert.NoDirExists(t, backupsPath)
|
||||
|
||||
err = dataprovider.AddEventRule(rule, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
job.Run()
|
||||
assert.DirExists(t, backupsPath)
|
||||
|
||||
err = dataprovider.DeleteEventRule(rule.Name, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.DeleteEventAction(action.Name, "", "")
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(backupsPath)
|
||||
assert.NoError(t, err)
|
||||
stopEventScheduler()
|
||||
}
|
43
internal/common/eventscheduler.go
Normal file
43
internal/common/eventscheduler.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
// Copyright (C) 2019-2022 Nicola Murino
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published
|
||||
// by the Free Software Foundation, version 3.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
var (
|
||||
eventScheduler *cron.Cron
|
||||
)
|
||||
|
||||
func stopEventScheduler() {
|
||||
if eventScheduler != nil {
|
||||
eventScheduler.Stop()
|
||||
eventScheduler = nil
|
||||
}
|
||||
}
|
||||
|
||||
func startEventScheduler() {
|
||||
stopEventScheduler()
|
||||
|
||||
eventScheduler = cron.New(cron.WithLocation(time.UTC))
|
||||
_, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules)
|
||||
util.PanicOnError(err)
|
||||
eventScheduler.Start()
|
||||
}
|
|
@ -18,6 +18,7 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
@ -77,6 +78,7 @@ var (
|
|||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
logFilePath string
|
||||
backupsPath string
|
||||
testFileContent = []byte("test data")
|
||||
lastReceivedEmail receivedEmail
|
||||
)
|
||||
|
@ -84,6 +86,7 @@ var (
|
|||
func TestMain(m *testing.M) {
|
||||
homeBasePath = os.TempDir()
|
||||
logFilePath = filepath.Join(configDir, "common_test.log")
|
||||
backupsPath = filepath.Join(os.TempDir(), "backups")
|
||||
logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel)
|
||||
|
||||
os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1")
|
||||
|
@ -95,6 +98,7 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(1)
|
||||
}
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.BackupsPath = backupsPath
|
||||
logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver)
|
||||
|
||||
err = common.Initialize(config.GetCommonConfig(), 0)
|
||||
|
@ -203,6 +207,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
exitCode := m.Run()
|
||||
os.Remove(logFilePath)
|
||||
os.RemoveAll(backupsPath)
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
|
@ -2848,7 +2853,7 @@ func TestEventRule(t *testing.T) {
|
|||
EmailConfig: dataprovider.EventActionEmailConfig{
|
||||
Recipients: []string{"test1@example.com", "test2@example.com"},
|
||||
Subject: `New "{{Event}}" from "{{Name}}"`,
|
||||
Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}}",
|
||||
Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}} Data: {{ObjectData}}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -2987,6 +2992,10 @@ func TestEventRule(t *testing.T) {
|
|||
Key: "SFTPGO_ACTION_PATH",
|
||||
Value: "{{FsPath}}",
|
||||
},
|
||||
{
|
||||
Key: "CUSTOM_ENV_VAR",
|
||||
Value: "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -3090,6 +3099,176 @@ func TestEventRule(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEventRuleProviderEvents(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
}
|
||||
smtpCfg := smtp.Config{
|
||||
Host: "127.0.0.1",
|
||||
Port: 2525,
|
||||
From: "notification@example.com",
|
||||
TemplatesPath: "templates",
|
||||
}
|
||||
err := smtpCfg.Initialize(configDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh")
|
||||
outPath := filepath.Join(os.TempDir(), "provider_out.json")
|
||||
err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
a1 := dataprovider.BaseEventAction{
|
||||
Name: "a1",
|
||||
Type: dataprovider.ActionTypeCommand,
|
||||
Options: dataprovider.BaseEventActionOptions{
|
||||
CmdConfig: dataprovider.EventActionCommandConfig{
|
||||
Cmd: saveObjectScriptPath,
|
||||
Timeout: 10,
|
||||
EnvVars: []dataprovider.KeyValue{
|
||||
{
|
||||
Key: "SFTPGO_OBJECT_DATA",
|
||||
Value: "{{ObjectData}}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
a2 := dataprovider.BaseEventAction{
|
||||
Name: "a2",
|
||||
Type: dataprovider.ActionTypeEmail,
|
||||
Options: dataprovider.BaseEventActionOptions{
|
||||
EmailConfig: dataprovider.EventActionEmailConfig{
|
||||
Recipients: []string{"test3@example.com"},
|
||||
Subject: `New "{{Event}}" from "{{Name}}"`,
|
||||
Body: "Object name: {{ObjectName}} object type: {{ObjectType}} Data: {{ObjectData}}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
a3 := dataprovider.BaseEventAction{
|
||||
Name: "a3",
|
||||
Type: dataprovider.ActionTypeEmail,
|
||||
Options: dataprovider.BaseEventActionOptions{
|
||||
EmailConfig: dataprovider.EventActionEmailConfig{
|
||||
Recipients: []string{"failure@example.com"},
|
||||
Subject: `Failed "{{Event}}" from "{{Name}}"`,
|
||||
Body: "Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}",
|
||||
},
|
||||
},
|
||||
}
|
||||
action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r := dataprovider.EventRule{
|
||||
Name: "rule",
|
||||
Trigger: dataprovider.EventTriggerProviderEvent,
|
||||
Conditions: dataprovider.EventConditions{
|
||||
ProviderEvents: []string{"update"},
|
||||
},
|
||||
Actions: []dataprovider.EventAction{
|
||||
{
|
||||
BaseEventAction: dataprovider.BaseEventAction{
|
||||
Name: action1.Name,
|
||||
},
|
||||
Order: 1,
|
||||
Options: dataprovider.EventActionOptions{
|
||||
StopOnFailure: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
BaseEventAction: dataprovider.BaseEventAction{
|
||||
Name: action2.Name,
|
||||
},
|
||||
Order: 2,
|
||||
},
|
||||
{
|
||||
BaseEventAction: dataprovider.BaseEventAction{
|
||||
Name: action3.Name,
|
||||
},
|
||||
Order: 3,
|
||||
Options: dataprovider.EventActionOptions{
|
||||
IsFailureAction: true,
|
||||
StopOnFailure: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
|
||||
lastReceivedEmail.reset()
|
||||
// create and update a folder to trigger the rule
|
||||
folder := vfs.BaseVirtualFolder{
|
||||
Name: "ftest rule",
|
||||
MappedPath: filepath.Join(os.TempDir(), "p"),
|
||||
}
|
||||
folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
// no action is triggered on add
|
||||
assert.NoFileExists(t, outPath)
|
||||
// update the folder
|
||||
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.FileExists(t, outPath) {
|
||||
content, err := os.ReadFile(outPath)
|
||||
assert.NoError(t, err)
|
||||
var folderGet vfs.BaseVirtualFolder
|
||||
err = json.Unmarshal(content, &folderGet)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, folder, folderGet)
|
||||
err = os.Remove(outPath)
|
||||
assert.NoError(t, err)
|
||||
assert.Eventually(t, func() bool {
|
||||
return lastReceivedEmail.get().From != ""
|
||||
}, 3000*time.Millisecond, 100*time.Millisecond)
|
||||
email := lastReceivedEmail.get()
|
||||
assert.Len(t, email.To, 1)
|
||||
assert.True(t, util.Contains(email.To, "test3@example.com"))
|
||||
assert.Contains(t, string(email.Data), `Subject: New "update" from "admin"`)
|
||||
}
|
||||
// now delete the script to generate an error
|
||||
lastReceivedEmail.reset()
|
||||
err = os.Remove(saveObjectScriptPath)
|
||||
assert.NoError(t, err)
|
||||
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.NoFileExists(t, outPath)
|
||||
assert.Eventually(t, func() bool {
|
||||
return lastReceivedEmail.get().From != ""
|
||||
}, 3000*time.Millisecond, 100*time.Millisecond)
|
||||
email := lastReceivedEmail.get()
|
||||
assert.Len(t, email.To, 1)
|
||||
assert.True(t, util.Contains(email.To, "failure@example.com"))
|
||||
assert.Contains(t, string(email.Data), `Subject: Failed "update" from "admin"`)
|
||||
assert.Contains(t, string(email.Data), fmt.Sprintf("Object name: %s object type: folder", folder.Name))
|
||||
lastReceivedEmail.reset()
|
||||
// generate an error for the failure action
|
||||
smtpCfg = smtp.Config{}
|
||||
err = smtpCfg.Initialize(configDir)
|
||||
require.NoError(t, err)
|
||||
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.NoFileExists(t, outPath)
|
||||
email = lastReceivedEmail.get()
|
||||
assert.Len(t, email.To, 0)
|
||||
|
||||
_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSyncUploadAction(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
|
@ -4113,6 +4292,13 @@ func getUploadScriptContent(movedPath string, exitStatus int) []byte {
|
|||
return content
|
||||
}
|
||||
|
||||
func getSaveProviderObjectScriptContent(outFilePath string, exitStatus int) []byte {
|
||||
content := []byte("#!/bin/sh\n\n")
|
||||
content = append(content, []byte(fmt.Sprintf("echo ${SFTPGO_OBJECT_DATA} > %v\n", outFilePath))...)
|
||||
content = append(content, []byte(fmt.Sprintf("exit %d", exitStatus))...)
|
||||
return content
|
||||
}
|
||||
|
||||
func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) {
|
||||
return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
|
||||
Period: 30,
|
||||
|
|
|
@ -63,17 +63,8 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
|
|||
Timestamp: time.Now().UnixNano(),
|
||||
}, object)
|
||||
}
|
||||
if EventManager.hasProviderEvents() {
|
||||
EventManager.handleProviderEvent(EventParams{
|
||||
Name: executor,
|
||||
ObjectName: objectName,
|
||||
Event: operation,
|
||||
Status: 1,
|
||||
ObjectType: objectType,
|
||||
IP: ip,
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
Object: object,
|
||||
})
|
||||
if fnHandleRuleForProviderEvent != nil {
|
||||
fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, object)
|
||||
}
|
||||
if config.Actions.Hook == "" {
|
||||
return
|
||||
|
|
|
@ -191,6 +191,9 @@ var (
|
|||
lastLoginMinDelay = 10 * time.Minute
|
||||
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9-_.~]+$")
|
||||
tempPath string
|
||||
fnReloadRules FnReloadRules
|
||||
fnRemoveRule FnRemoveRule
|
||||
fnHandleRuleForProviderEvent FnHandleRuleForProviderEvent
|
||||
)
|
||||
|
||||
func initSQLTables() {
|
||||
|
@ -214,6 +217,22 @@ func initSQLTables() {
|
|||
sqlTableSchemaVersion = "schema_version"
|
||||
}
|
||||
|
||||
// FnReloadRules defined the callback to reload event rules
|
||||
type FnReloadRules func()
|
||||
|
||||
// FnRemoveRule defines the callback to remove an event rule
|
||||
type FnRemoveRule func(name string)
|
||||
|
||||
// FnHandleRuleForProviderEvent define the callback to handle event rules for provider events
|
||||
type FnHandleRuleForProviderEvent func(operation, executor, ip, objectType, objectName string, object plugin.Renderer)
|
||||
|
||||
// SetEventRulesCallbacks sets the event rules callbacks
|
||||
func SetEventRulesCallbacks(reload FnReloadRules, remove FnRemoveRule, handle FnHandleRuleForProviderEvent) {
|
||||
fnReloadRules = reload
|
||||
fnRemoveRule = remove
|
||||
fnHandleRuleForProviderEvent = handle
|
||||
}
|
||||
|
||||
type schemaVersion struct {
|
||||
Version int
|
||||
}
|
||||
|
@ -487,31 +506,36 @@ func (c *Config) requireCustomTLSForMySQL() bool {
|
|||
func (c *Config) doBackup() error {
|
||||
now := time.Now().UTC()
|
||||
outputFile := filepath.Join(c.BackupsPath, fmt.Sprintf("backup_%s_%d.json", now.Weekday(), now.Hour()))
|
||||
eventManagerLog(logger.LevelDebug, "starting backup to file %q", outputFile)
|
||||
providerLog(logger.LevelDebug, "starting backup to file %q", outputFile)
|
||||
err := os.MkdirAll(filepath.Dir(outputFile), 0700)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err)
|
||||
providerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err)
|
||||
return fmt.Errorf("unable to create backup dir: %w", err)
|
||||
}
|
||||
backup, err := DumpData()
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute backup: %v", err)
|
||||
providerLog(logger.LevelError, "unable to execute backup: %v", err)
|
||||
return fmt.Errorf("unable to dump backup data: %w", err)
|
||||
}
|
||||
dump, err := json.Marshal(backup)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err)
|
||||
providerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err)
|
||||
return fmt.Errorf("unable to marshal backup data as JSON: %w", err)
|
||||
}
|
||||
err = os.WriteFile(outputFile, dump, 0600)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to save backup: %v", err)
|
||||
providerLog(logger.LevelError, "unable to save backup: %v", err)
|
||||
return fmt.Errorf("unable to save backup: %w", err)
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "auto backup saved to %q", outputFile)
|
||||
providerLog(logger.LevelDebug, "backup saved to %q", outputFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteBackup executes a backup
|
||||
func ExecuteBackup() error {
|
||||
return config.doBackup()
|
||||
}
|
||||
|
||||
// ConvertName converts the given name based on the configured rules
|
||||
func ConvertName(name string) string {
|
||||
return config.convertName(name)
|
||||
|
@ -1568,7 +1592,9 @@ func AddEventAction(action *BaseEventAction, executor, ipAddress string) error {
|
|||
func UpdateEventAction(action *BaseEventAction, executor, ipAddress string) error {
|
||||
err := provider.updateEventAction(action)
|
||||
if err == nil {
|
||||
EventManager.loadRules()
|
||||
if fnReloadRules != nil {
|
||||
fnReloadRules()
|
||||
}
|
||||
executeAction(operationUpdate, executor, ipAddress, actionObjectEventAction, action.Name, action)
|
||||
}
|
||||
return err
|
||||
|
@ -1597,6 +1623,11 @@ func GetEventRules(limit, offset int, order string) ([]EventRule, error) {
|
|||
return provider.getEventRules(limit, offset, order)
|
||||
}
|
||||
|
||||
// GetRecentlyUpdatedRules returns the event rules updated after the specified time
|
||||
func GetRecentlyUpdatedRules(after int64) ([]EventRule, error) {
|
||||
return provider.getRecentlyUpdatedRules(after)
|
||||
}
|
||||
|
||||
// EventRuleExists returns the event rule with the given name if it exists
|
||||
func EventRuleExists(name string) (EventRule, error) {
|
||||
name = config.convertName(name)
|
||||
|
@ -1608,7 +1639,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error {
|
|||
rule.Name = config.convertName(rule.Name)
|
||||
err := provider.addEventRule(rule)
|
||||
if err == nil {
|
||||
EventManager.loadRules()
|
||||
if fnReloadRules != nil {
|
||||
fnReloadRules()
|
||||
}
|
||||
executeAction(operationAdd, executor, ipAddress, actionObjectEventRule, rule.Name, rule)
|
||||
}
|
||||
return err
|
||||
|
@ -1618,7 +1651,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error {
|
|||
func UpdateEventRule(rule *EventRule, executor, ipAddress string) error {
|
||||
err := provider.updateEventRule(rule)
|
||||
if err == nil {
|
||||
EventManager.loadRules()
|
||||
if fnReloadRules != nil {
|
||||
fnReloadRules()
|
||||
}
|
||||
executeAction(operationUpdate, executor, ipAddress, actionObjectEventRule, rule.Name, rule)
|
||||
}
|
||||
return err
|
||||
|
@ -1633,12 +1668,39 @@ func DeleteEventRule(name string, executor, ipAddress string) error {
|
|||
}
|
||||
err = provider.deleteEventRule(rule, config.IsShared == 1)
|
||||
if err == nil {
|
||||
EventManager.RemoveRule(rule.Name)
|
||||
if fnRemoveRule != nil {
|
||||
fnRemoveRule(rule.Name)
|
||||
}
|
||||
executeAction(operationDelete, executor, ipAddress, actionObjectEventRule, rule.Name, &rule)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveEventRule delets an existing event rule without marking it as deleted
|
||||
func RemoveEventRule(rule EventRule) error {
|
||||
return provider.deleteEventRule(rule, false)
|
||||
}
|
||||
|
||||
// GetTaskByName returns the task with the specified name
|
||||
func GetTaskByName(name string) (Task, error) {
|
||||
return provider.getTaskByName(name)
|
||||
}
|
||||
|
||||
// AddTask add a task with the specified name
|
||||
func AddTask(name string) error {
|
||||
return provider.addTask(name)
|
||||
}
|
||||
|
||||
// UpdateTask updates the task with the specified name and version
|
||||
func UpdateTask(name string, version int64) error {
|
||||
return provider.updateTask(name, version)
|
||||
}
|
||||
|
||||
// UpdateTaskTimestamp updates the timestamp for the task with the specified name
|
||||
func UpdateTaskTimestamp(name string) error {
|
||||
return provider.updateTaskTimestamp(name)
|
||||
}
|
||||
|
||||
// HasAdmin returns true if the first admin has been created
|
||||
// and so SFTPGo is ready to be used
|
||||
func HasAdmin() bool {
|
||||
|
@ -1971,6 +2033,16 @@ func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtua
|
|||
return provider.getFolders(limit, offset, order, minimal)
|
||||
}
|
||||
|
||||
// DumpUsers returns all users, including confidential data
|
||||
func DumpUsers() ([]User, error) {
|
||||
return provider.dumpUsers()
|
||||
}
|
||||
|
||||
// DumpFolders returns all folders, including confidential data
|
||||
func DumpFolders() ([]vfs.BaseVirtualFolder, error) {
|
||||
return provider.dumpFolders()
|
||||
}
|
||||
|
||||
// DumpData returns all users, groups, folders, admins, api keys, shares, actions, rules
|
||||
func DumpData() (BackupData, error) {
|
||||
var data BackupData
|
||||
|
|
|
@ -15,16 +15,10 @@
|
|||
package dataprovider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
@ -34,9 +28,7 @@ import (
|
|||
|
||||
"github.com/drakkan/sftpgo/v2/internal/kms"
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/smtp"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
"github.com/drakkan/sftpgo/v2/internal/vfs"
|
||||
)
|
||||
|
||||
// Supported event actions
|
||||
|
@ -204,25 +196,8 @@ func (c *EventActionHTTPConfig) validate(additionalData string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *EventActionHTTPConfig) getEndpoint(replacer *strings.Replacer) (string, error) {
|
||||
if len(c.QueryParameters) > 0 {
|
||||
u, err := url.Parse(c.Endpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid endpoint: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
|
||||
for _, keyVal := range c.QueryParameters {
|
||||
q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
return c.Endpoint, nil
|
||||
}
|
||||
|
||||
func (c *EventActionHTTPConfig) getHTTPClient() *http.Client {
|
||||
// GetHTTPClient returns an HTTP client based on the config
|
||||
func (c *EventActionHTTPConfig) GetHTTPClient() *http.Client {
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(c.Timeout) * time.Second,
|
||||
}
|
||||
|
@ -241,63 +216,6 @@ func (c *EventActionHTTPConfig) getHTTPClient() *http.Client {
|
|||
return client
|
||||
}
|
||||
|
||||
func (c *EventActionHTTPConfig) execute(params EventParams) error {
|
||||
if !c.Password.IsEmpty() {
|
||||
if err := c.Password.TryDecrypt(); err != nil {
|
||||
return fmt.Errorf("unable to decrypt password: %w", err)
|
||||
}
|
||||
}
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
if !addObjectData {
|
||||
if strings.Contains(c.Body, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
endpoint, err := c.getEndpoint(replacer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if c.Body != "" && c.Method != http.MethodGet {
|
||||
body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer))
|
||||
}
|
||||
req, err := http.NewRequest(c.Method, endpoint, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Username != "" {
|
||||
req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData())
|
||||
}
|
||||
for _, keyVal := range c.Headers {
|
||||
req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
|
||||
}
|
||||
client := c.getHTTPClient()
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
startTime := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v",
|
||||
endpoint, time.Since(startTime), err)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
eventManagerLog(logger.LevelDebug, "http notification sent, endopoint: %s, elapsed: %s, status code: %d",
|
||||
endpoint, time.Since(startTime), resp.StatusCode)
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EventActionCommandConfig defines the configuration for a command event target
|
||||
type EventActionCommandConfig struct {
|
||||
Cmd string `json:"cmd"`
|
||||
|
@ -323,43 +241,6 @@ func (c *EventActionCommandConfig) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *EventActionCommandConfig) getEnvVars(params EventParams) []string {
|
||||
envVars := make([]string, 0, len(c.EnvVars))
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
for _, k := range c.EnvVars {
|
||||
if strings.Contains(k.Value, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
for _, keyVal := range c.EnvVars {
|
||||
envVars = append(envVars, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)))
|
||||
}
|
||||
|
||||
return envVars
|
||||
}
|
||||
|
||||
func (c *EventActionCommandConfig) execute(params EventParams) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, c.Cmd)
|
||||
cmd.Env = append(cmd.Env, os.Environ()...)
|
||||
cmd.Env = append(cmd.Env, c.getEnvVars(params)...)
|
||||
|
||||
startTime := time.Now()
|
||||
err := cmd.Run()
|
||||
|
||||
eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v",
|
||||
c.Cmd, time.Since(startTime), err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// EventActionEmailConfig defines the configuration options for SMTP event actions
|
||||
type EventActionEmailConfig struct {
|
||||
Recipients []string `json:"recipients"`
|
||||
|
@ -391,24 +272,6 @@ func (o *EventActionEmailConfig) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (o *EventActionEmailConfig) execute(params EventParams) error {
|
||||
addObjectData := false
|
||||
if params.Object != nil {
|
||||
if strings.Contains(o.Body, "{{ObjectData}}") {
|
||||
addObjectData = true
|
||||
}
|
||||
}
|
||||
replacements := params.getStringReplacements(addObjectData)
|
||||
replacer := strings.NewReplacer(replacements...)
|
||||
body := replaceWithReplacer(o.Body, replacer)
|
||||
subject := replaceWithReplacer(o.Subject, replacer)
|
||||
startTime := time.Now()
|
||||
err := smtp.SendEmail(o.Recipients, subject, body, smtp.EmailContentTypeTextPlain)
|
||||
eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v",
|
||||
time.Since(startTime), err)
|
||||
return err
|
||||
}
|
||||
|
||||
// BaseEventActionOptions defines the supported configuration options for a base event actions
|
||||
type BaseEventActionOptions struct {
|
||||
HTTPConfig EventActionHTTPConfig `json:"http_config"`
|
||||
|
@ -560,130 +423,6 @@ func (a *BaseEventAction) validate() error {
|
|||
return a.Options.validate(a.Type, a.Name)
|
||||
}
|
||||
|
||||
func (a *BaseEventAction) doUsersQuotaReset(conditions ConditionOptions) error {
|
||||
users, err := provider.dumpUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get users: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, user := range users {
|
||||
if !checkConditionPatterns(user.Username, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match",
|
||||
user.Username)
|
||||
continue
|
||||
}
|
||||
if !QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
numFiles, size, err := user.ScanQuota()
|
||||
QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
err = UpdateUserQuota(&user, numFiles, size, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("quota reset failed for users: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *BaseEventAction) doFoldersQuotaReset(conditions ConditionOptions) error {
|
||||
folders, err := provider.dumpFolders()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get folders: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, folder := range folders {
|
||||
if !checkConditionPatterns(folder.Name, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match",
|
||||
folder.Name)
|
||||
continue
|
||||
}
|
||||
if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
f := vfs.VirtualFolder{
|
||||
BaseVirtualFolder: folder,
|
||||
VirtualPath: "/",
|
||||
}
|
||||
numFiles, size, err := f.ScanQuota()
|
||||
QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, err)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
err = UpdateVirtualFolderQuota(&folder, numFiles, size, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err)
|
||||
failedResets = append(failedResets, folder.Name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("quota reset failed for folders: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *BaseEventAction) doTransferQuotaReset(conditions ConditionOptions) error {
|
||||
users, err := provider.dumpUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get users: %w", err)
|
||||
}
|
||||
var failedResets []string
|
||||
for _, user := range users {
|
||||
if !checkConditionPatterns(user.Username, conditions.Names) {
|
||||
eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match",
|
||||
user.Username)
|
||||
continue
|
||||
}
|
||||
err = UpdateUserTransferQuota(&user, 0, 0, true)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err)
|
||||
failedResets = append(failedResets, user.Username)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(failedResets) > 0 {
|
||||
return fmt.Errorf("transfer quota reset failed for users: %+v", failedResets)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *BaseEventAction) execute(params EventParams, conditions ConditionOptions) error {
|
||||
switch a.Type {
|
||||
case ActionTypeHTTP:
|
||||
return a.Options.HTTPConfig.execute(params)
|
||||
case ActionTypeCommand:
|
||||
return a.Options.CmdConfig.execute(params)
|
||||
case ActionTypeEmail:
|
||||
return a.Options.EmailConfig.execute(params)
|
||||
case ActionTypeBackup:
|
||||
return config.doBackup()
|
||||
case ActionTypeUserQuotaReset:
|
||||
return a.doUsersQuotaReset(conditions)
|
||||
case ActionTypeFolderQuotaReset:
|
||||
return a.doFoldersQuotaReset(conditions)
|
||||
case ActionTypeTransferQuotaReset:
|
||||
return a.doTransferQuotaReset(conditions)
|
||||
default:
|
||||
return fmt.Errorf("unsupported action type: %d", a.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// EventActionOptions defines the supported configuration options for an event action
|
||||
type EventActionOptions struct {
|
||||
IsFailureAction bool `json:"is_failure_action"`
|
||||
|
@ -731,18 +470,6 @@ type ConditionPattern struct {
|
|||
InverseMatch bool `json:"inverse_match,omitempty"`
|
||||
}
|
||||
|
||||
func (p *ConditionPattern) match(name string) bool {
|
||||
matched, err := path.Match(p.Pattern, name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err)
|
||||
return false
|
||||
}
|
||||
if p.InverseMatch {
|
||||
return !matched
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
func (p *ConditionPattern) validate() error {
|
||||
if p.Pattern == "" {
|
||||
return util.NewValidationError("empty condition pattern not allowed")
|
||||
|
@ -826,12 +553,13 @@ type Schedule struct {
|
|||
Month string `json:"month"`
|
||||
}
|
||||
|
||||
func (s *Schedule) getCronSpec() string {
|
||||
// GetCronSpec returns the cron compatible schedule string
|
||||
func (s *Schedule) GetCronSpec() string {
|
||||
return fmt.Sprintf("0 %s %s %s %s", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek)
|
||||
}
|
||||
|
||||
func (s *Schedule) validate() error {
|
||||
_, err := cron.ParseStandard(s.getCronSpec())
|
||||
_, err := cron.ParseStandard(s.GetCronSpec())
|
||||
if err != nil {
|
||||
return util.NewValidationError(fmt.Sprintf("invalid schedule, hour: %q, day of month: %q, month: %q, day of week: %q",
|
||||
s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek))
|
||||
|
@ -871,51 +599,6 @@ func (c *EventConditions) getACopy() EventConditions {
|
|||
}
|
||||
}
|
||||
|
||||
// ProviderEventMatch returns true if the specified provider event match
|
||||
func (c *EventConditions) ProviderEventMatch(params EventParams) bool {
|
||||
if !util.Contains(c.ProviderEvents, params.Event) {
|
||||
return false
|
||||
}
|
||||
if !checkConditionPatterns(params.Name, c.Options.Names) {
|
||||
return false
|
||||
}
|
||||
if len(c.Options.ProviderObjects) > 0 && !util.Contains(c.Options.ProviderObjects, params.ObjectType) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// FsEventMatch returns true if the specified filesystem event match
|
||||
func (c *EventConditions) FsEventMatch(params EventParams) bool {
|
||||
if !util.Contains(c.FsEvents, params.Event) {
|
||||
return false
|
||||
}
|
||||
if !checkConditionPatterns(params.Name, c.Options.Names) {
|
||||
return false
|
||||
}
|
||||
if !checkConditionPatterns(params.VirtualPath, c.Options.FsPaths) {
|
||||
if !checkConditionPatterns(params.ObjectName, c.Options.FsPaths) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(c.Options.Protocols) > 0 && !util.Contains(c.Options.Protocols, params.Protocol) {
|
||||
return false
|
||||
}
|
||||
if params.Event == "upload" || params.Event == "download" {
|
||||
if c.Options.MinFileSize > 0 {
|
||||
if params.FileSize < c.Options.MinFileSize {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.Options.MaxFileSize > 0 {
|
||||
if params.FileSize > c.Options.MaxFileSize {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *EventConditions) validate(trigger int) error {
|
||||
switch trigger {
|
||||
case EventTriggerFsEvent:
|
||||
|
@ -1015,7 +698,9 @@ func (r *EventRule) getACopy() EventRule {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *EventRule) guardFromConcurrentExecution() bool {
|
||||
// GuardFromConcurrentExecution returns true if the rule cannot be executed concurrently
|
||||
// from multiple instances
|
||||
func (r *EventRule) GuardFromConcurrentExecution() bool {
|
||||
if config.IsShared == 0 {
|
||||
return false
|
||||
}
|
||||
|
@ -1102,6 +787,28 @@ func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) {
|
|||
return json.Marshal(r)
|
||||
}
|
||||
|
||||
func cloneKeyValues(keyVals []KeyValue) []KeyValue {
|
||||
res := make([]KeyValue, 0, len(keyVals))
|
||||
for _, kv := range keyVals {
|
||||
res = append(res, KeyValue{
|
||||
Key: kv.Key,
|
||||
Value: kv.Value,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern {
|
||||
res := make([]ConditionPattern, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
res = append(res, ConditionPattern{
|
||||
Pattern: p.Pattern,
|
||||
InverseMatch: p.InverseMatch,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Task stores the state for a scheduled task
|
||||
type Task struct {
|
||||
Name string `json:"name"`
|
||||
|
|
|
@ -1,474 +0,0 @@
|
|||
// Copyright (C) 2019-2022 Nicola Murino
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published
|
||||
// by the Free Software Foundation, version 3.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dataprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/plugin"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
var (
|
||||
// EventManager handle the supported event rules actions
|
||||
EventManager EventRulesContainer
|
||||
)
|
||||
|
||||
func init() {
|
||||
EventManager = EventRulesContainer{
|
||||
schedulesMapping: make(map[string][]cron.EntryID),
|
||||
}
|
||||
}
|
||||
|
||||
// EventRulesContainer stores event rules by trigger
|
||||
type EventRulesContainer struct {
|
||||
sync.RWMutex
|
||||
FsEvents []EventRule
|
||||
ProviderEvents []EventRule
|
||||
Schedules []EventRule
|
||||
schedulesMapping map[string][]cron.EntryID
|
||||
lastLoad int64
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) getLastLoadTime() int64 {
|
||||
return atomic.LoadInt64(&r.lastLoad)
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) setLastLoadTime(modTime int64) {
|
||||
atomic.StoreInt64(&r.lastLoad, modTime)
|
||||
}
|
||||
|
||||
// RemoveRule deletes the rule with the specified name
|
||||
func (r *EventRulesContainer) RemoveRule(name string) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
r.removeRuleInternal(name)
|
||||
eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d",
|
||||
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) removeRuleInternal(name string) {
|
||||
for idx := range r.FsEvents {
|
||||
if r.FsEvents[idx].Name == name {
|
||||
lastIdx := len(r.FsEvents) - 1
|
||||
r.FsEvents[idx] = r.FsEvents[lastIdx]
|
||||
r.FsEvents = r.FsEvents[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
for idx := range r.ProviderEvents {
|
||||
if r.ProviderEvents[idx].Name == name {
|
||||
lastIdx := len(r.ProviderEvents) - 1
|
||||
r.ProviderEvents[idx] = r.ProviderEvents[lastIdx]
|
||||
r.ProviderEvents = r.ProviderEvents[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
for idx := range r.Schedules {
|
||||
if r.Schedules[idx].Name == name {
|
||||
if schedules, ok := r.schedulesMapping[name]; ok {
|
||||
for _, entryID := range schedules {
|
||||
eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name)
|
||||
scheduler.Remove(entryID)
|
||||
}
|
||||
delete(r.schedulesMapping, name)
|
||||
}
|
||||
|
||||
lastIdx := len(r.Schedules) - 1
|
||||
r.Schedules[idx] = r.Schedules[lastIdx]
|
||||
r.Schedules = r.Schedules[:lastIdx]
|
||||
eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) addUpdateRuleInternal(rule EventRule) {
|
||||
r.removeRuleInternal(rule.Name)
|
||||
if rule.DeletedAt > 0 {
|
||||
deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt)
|
||||
if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
|
||||
eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt)
|
||||
go provider.deleteEventRule(rule, false) //nolint:errcheck
|
||||
}
|
||||
return
|
||||
}
|
||||
switch rule.Trigger {
|
||||
case EventTriggerFsEvent:
|
||||
r.FsEvents = append(r.FsEvents, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name)
|
||||
case EventTriggerProviderEvent:
|
||||
r.ProviderEvents = append(r.ProviderEvents, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name)
|
||||
case EventTriggerSchedule:
|
||||
r.Schedules = append(r.Schedules, rule)
|
||||
eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name)
|
||||
for _, schedule := range rule.Conditions.Schedules {
|
||||
cronSpec := schedule.getCronSpec()
|
||||
job := &cronJob{
|
||||
ruleName: ConvertName(rule.Name),
|
||||
}
|
||||
entryID, err := scheduler.AddJob(cronSpec, job)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to add scheduled rule %q: %v", rule.Name, err)
|
||||
} else {
|
||||
r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID)
|
||||
eventManagerLog(logger.LevelDebug, "scheduled rule %q added, id: %d, active scheduling rules: %d",
|
||||
rule.Name, entryID, len(r.schedulesMapping))
|
||||
}
|
||||
}
|
||||
default:
|
||||
eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) loadRules() {
|
||||
eventManagerLog(logger.LevelDebug, "loading updated rules")
|
||||
modTime := util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
rules, err := provider.getRecentlyUpdatedRules(r.getLastLoadTime())
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to load event rules: %v", err)
|
||||
return
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules))
|
||||
|
||||
if len(rules) > 0 {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
for _, rule := range rules {
|
||||
r.addUpdateRuleInternal(rule)
|
||||
}
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d",
|
||||
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
|
||||
|
||||
r.setLastLoadTime(modTime)
|
||||
}
|
||||
|
||||
// HasFsRules returns true if there are any rules for filesystem event triggers
|
||||
func (r *EventRulesContainer) HasFsRules() bool {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
return len(r.FsEvents) > 0
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) hasProviderEvents() bool {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
return len(r.ProviderEvents) > 0
|
||||
}
|
||||
|
||||
// HandleFsEvent executes the rules actions defined for the specified event
|
||||
func (r *EventRulesContainer) HandleFsEvent(params EventParams) error {
|
||||
r.RLock()
|
||||
|
||||
var rulesWithSyncActions, rulesAsync []EventRule
|
||||
for _, rule := range r.FsEvents {
|
||||
if rule.Conditions.FsEventMatch(params) {
|
||||
hasSyncActions := false
|
||||
for _, action := range rule.Actions {
|
||||
if action.Options.ExecuteSync {
|
||||
hasSyncActions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasSyncActions {
|
||||
rulesWithSyncActions = append(rulesWithSyncActions, rule)
|
||||
} else {
|
||||
rulesAsync = append(rulesAsync, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.RUnlock()
|
||||
|
||||
if len(rulesAsync) > 0 {
|
||||
go executeAsyncActions(rulesAsync, params)
|
||||
}
|
||||
|
||||
if len(rulesWithSyncActions) > 0 {
|
||||
return executeSyncActions(rulesWithSyncActions, params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *EventRulesContainer) handleProviderEvent(params EventParams) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
||||
var rules []EventRule
|
||||
for _, rule := range r.ProviderEvents {
|
||||
if rule.Conditions.ProviderEventMatch(params) {
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
go executeAsyncActions(rules, params)
|
||||
}
|
||||
|
||||
// EventParams defines the supported event parameters
|
||||
type EventParams struct {
|
||||
Name string
|
||||
Event string
|
||||
Status int
|
||||
VirtualPath string
|
||||
FsPath string
|
||||
VirtualTargetPath string
|
||||
FsTargetPath string
|
||||
ObjectName string
|
||||
ObjectType string
|
||||
FileSize int64
|
||||
Protocol string
|
||||
IP string
|
||||
Timestamp int64
|
||||
Object plugin.Renderer
|
||||
}
|
||||
|
||||
func (p *EventParams) getStringReplacements(addObjectData bool) []string {
|
||||
replacements := []string{
|
||||
"{{Name}}", p.Name,
|
||||
"{{Event}}", p.Event,
|
||||
"{{Status}}", fmt.Sprintf("%d", p.Status),
|
||||
"{{VirtualPath}}", p.VirtualPath,
|
||||
"{{FsPath}}", p.FsPath,
|
||||
"{{VirtualTargetPath}}", p.VirtualTargetPath,
|
||||
"{{FsTargetPath}}", p.FsTargetPath,
|
||||
"{{ObjectName}}", p.ObjectName,
|
||||
"{{ObjectType}}", p.ObjectType,
|
||||
"{{FileSize}}", fmt.Sprintf("%d", p.FileSize),
|
||||
"{{Protocol}}", p.Protocol,
|
||||
"{{IP}}", p.IP,
|
||||
"{{Timestamp}}", fmt.Sprintf("%d", p.Timestamp),
|
||||
}
|
||||
if addObjectData {
|
||||
data, err := p.Object.RenderAsJSON(p.Event != operationDelete)
|
||||
if err == nil {
|
||||
replacements = append(replacements, "{{ObjectData}}", string(data))
|
||||
}
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
func replaceWithReplacer(input string, replacer *strings.Replacer) string {
|
||||
if !strings.Contains(input, "{{") {
|
||||
return input
|
||||
}
|
||||
return replacer.Replace(input)
|
||||
}
|
||||
|
||||
// checkConditionPatterns returns false if patterns are defined and no match is found
|
||||
func checkConditionPatterns(name string, patterns []ConditionPattern) bool {
|
||||
if len(patterns) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range patterns {
|
||||
if p.match(name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func executeSyncActions(rules []EventRule, params EventParams) error {
|
||||
var errRes error
|
||||
|
||||
for _, rule := range rules {
|
||||
var failedActions []string
|
||||
for _, action := range rule.Actions {
|
||||
if !action.Options.IsFailureAction && action.Options.ExecuteSync {
|
||||
startTime := time.Now()
|
||||
if err := action.execute(params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
failedActions = append(failedActions, action.Name)
|
||||
// we return the last error, it is ok for now
|
||||
errRes = err
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
// execute async actions if any, including failure actions
|
||||
go executeRuleAsyncActions(rule, params, failedActions)
|
||||
}
|
||||
|
||||
return errRes
|
||||
}
|
||||
|
||||
func executeAsyncActions(rules []EventRule, params EventParams) {
|
||||
for _, rule := range rules {
|
||||
executeRuleAsyncActions(rule, params, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func executeRuleAsyncActions(rule EventRule, params EventParams, failedActions []string) {
|
||||
for _, action := range rule.Actions {
|
||||
if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
|
||||
startTime := time.Now()
|
||||
if err := action.execute(params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
failedActions = append(failedActions, action.Name)
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
if len(failedActions) > 0 {
|
||||
// execute failure actions
|
||||
for _, action := range rule.Actions {
|
||||
if action.Options.IsFailureAction {
|
||||
startTime := time.Now()
|
||||
if err := action.execute(params, rule.Conditions.Options); err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
|
||||
action.Name, rule.Name, time.Since(startTime), err)
|
||||
if action.Options.StopOnFailure {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
|
||||
action.Name, rule.Name, time.Since(startTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type cronJob struct {
|
||||
ruleName string
|
||||
}
|
||||
|
||||
func (j *cronJob) getTask(rule EventRule) (Task, error) {
|
||||
if rule.guardFromConcurrentExecution() {
|
||||
task, err := provider.getTaskByName(rule.Name)
|
||||
if _, ok := err.(*util.RecordNotFoundError); ok {
|
||||
eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
|
||||
task = Task{
|
||||
Name: rule.Name,
|
||||
UpdateAt: 0,
|
||||
Version: 0,
|
||||
}
|
||||
err = provider.addTask(rule.Name)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
|
||||
return task, err
|
||||
}
|
||||
} else {
|
||||
eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err)
|
||||
}
|
||||
return task, err
|
||||
}
|
||||
|
||||
return Task{}, nil
|
||||
}
|
||||
|
||||
func (j *cronJob) Run() {
|
||||
eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName)
|
||||
rule, err := provider.eventRuleExists(j.ruleName)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName)
|
||||
return
|
||||
}
|
||||
task, err := j.getTask(rule)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if task.Name != "" {
|
||||
updateInterval := 5 * time.Minute
|
||||
updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt)
|
||||
if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) {
|
||||
eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt)
|
||||
return
|
||||
}
|
||||
err = provider.updateTask(rule.Name, task.Version)
|
||||
if err != nil {
|
||||
eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v",
|
||||
rule.Name, err)
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(updateInterval)
|
||||
done := make(chan bool)
|
||||
|
||||
go func(taskName string) {
|
||||
eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName)
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := provider.updateTaskTimestamp(taskName)
|
||||
eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err)
|
||||
}
|
||||
}
|
||||
}(task.Name)
|
||||
|
||||
executeRuleAsyncActions(rule, EventParams{}, nil)
|
||||
|
||||
done <- true
|
||||
ticker.Stop()
|
||||
} else {
|
||||
executeRuleAsyncActions(rule, EventParams{}, nil)
|
||||
}
|
||||
eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName)
|
||||
}
|
||||
|
||||
func cloneKeyValues(keyVals []KeyValue) []KeyValue {
|
||||
res := make([]KeyValue, 0, len(keyVals))
|
||||
for _, kv := range keyVals {
|
||||
res = append(res, KeyValue{
|
||||
Key: kv.Key,
|
||||
Value: kv.Value,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern {
|
||||
res := make([]ConditionPattern, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
res = append(res, ConditionPattern{
|
||||
Pattern: p.Pattern,
|
||||
InverseMatch: p.InverseMatch,
|
||||
})
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func eventManagerLog(level logger.LogLevel, format string, v ...any) {
|
||||
logger.Log(level, "eventmanager", "", format, v...)
|
||||
}
|
|
@ -1,141 +0,0 @@
|
|||
// Copyright (C) 2019-2022 Nicola Murino
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published
|
||||
// by the Free Software Foundation, version 3.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dataprovider
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
)
|
||||
|
||||
var (
|
||||
// QuotaScans is the list of active quota scans
|
||||
QuotaScans ActiveScans
|
||||
)
|
||||
|
||||
// ActiveQuotaScan defines an active quota scan for a user home dir
|
||||
type ActiveQuotaScan struct {
|
||||
// Username to which the quota scan refers
|
||||
Username string `json:"username"`
|
||||
// quota scan start time as unix timestamp in milliseconds
|
||||
StartTime int64 `json:"start_time"`
|
||||
}
|
||||
|
||||
// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder
|
||||
type ActiveVirtualFolderQuotaScan struct {
|
||||
// folder name to which the quota scan refers
|
||||
Name string `json:"name"`
|
||||
// quota scan start time as unix timestamp in milliseconds
|
||||
StartTime int64 `json:"start_time"`
|
||||
}
|
||||
|
||||
// ActiveScans holds the active quota scans
|
||||
type ActiveScans struct {
|
||||
sync.RWMutex
|
||||
UserScans []ActiveQuotaScan
|
||||
FolderScans []ActiveVirtualFolderQuotaScan
|
||||
}
|
||||
|
||||
// GetUsersQuotaScans returns the active quota scans for users home directories
|
||||
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
scans := make([]ActiveQuotaScan, len(s.UserScans))
|
||||
copy(scans, s.UserScans)
|
||||
return scans
|
||||
}
|
||||
|
||||
// AddUserQuotaScan adds a user to the ones with active quota scans.
|
||||
// Returns false if the user has a quota scan already running
|
||||
func (s *ActiveScans) AddUserQuotaScan(username string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for _, scan := range s.UserScans {
|
||||
if scan.Username == username {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.UserScans = append(s.UserScans, ActiveQuotaScan{
|
||||
Username: username,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveUserQuotaScan removes a user from the ones with active quota scans.
|
||||
// Returns false if the user has no active quota scans
|
||||
func (s *ActiveScans) RemoveUserQuotaScan(username string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for idx, scan := range s.UserScans {
|
||||
if scan.Username == username {
|
||||
lastIdx := len(s.UserScans) - 1
|
||||
s.UserScans[idx] = s.UserScans[lastIdx]
|
||||
s.UserScans = s.UserScans[:lastIdx]
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetVFoldersQuotaScans returns the active quota scans for virtual folders
|
||||
func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans))
|
||||
copy(scans, s.FolderScans)
|
||||
return scans
|
||||
}
|
||||
|
||||
// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans.
|
||||
// Returns false if the folder has a quota scan already running
|
||||
func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for _, scan := range s.FolderScans {
|
||||
if scan.Name == folderName {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{
|
||||
Name: folderName,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans.
|
||||
// Returns false if the folder has no active quota scans
|
||||
func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
for idx, scan := range s.FolderScans {
|
||||
if scan.Name == folderName {
|
||||
lastIdx := len(s.FolderScans) - 1
|
||||
s.FolderScans[idx] = s.FolderScans[lastIdx]
|
||||
s.FolderScans = s.FolderScans[:lastIdx]
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -54,7 +54,9 @@ func startScheduler() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
EventManager.loadRules()
|
||||
if fnReloadRules != nil {
|
||||
fnReloadRules()
|
||||
}
|
||||
scheduler.Start()
|
||||
return nil
|
||||
}
|
||||
|
@ -77,7 +79,7 @@ func checkDataprovider() {
|
|||
}
|
||||
|
||||
func checkCacheUpdates() {
|
||||
providerLog(logger.LevelDebug, "start caches check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
|
||||
providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
|
||||
checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate)
|
||||
if err != nil {
|
||||
|
@ -101,8 +103,7 @@ func checkCacheUpdates() {
|
|||
}
|
||||
|
||||
lastUserCacheUpdate = checkTime
|
||||
EventManager.loadRules()
|
||||
providerLog(logger.LevelDebug, "end caches check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
|
||||
providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
|
||||
}
|
||||
|
||||
func setLastUserUpdate() {
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/common"
|
||||
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/util"
|
||||
|
@ -265,7 +266,7 @@ func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, sca
|
|||
return fmt.Errorf("unable to restore folder %#v: %w", folder.Name, err)
|
||||
}
|
||||
if scanQuota >= 1 {
|
||||
if dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
if common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
logger.Debug(logSender, "", "starting quota scan for restored folder: %#v", folder.Name)
|
||||
go doFolderQuotaScan(folder) //nolint:errcheck
|
||||
}
|
||||
|
@ -453,7 +454,7 @@ func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota i
|
|||
return fmt.Errorf("unable to restore user %#v: %w", user.Username, err)
|
||||
}
|
||||
if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
|
||||
if dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
if common.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
logger.Debug(logSender, "", "starting quota scan for restored user: %#v", user.Username)
|
||||
go doUserQuotaScan(user) //nolint:errcheck
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/common"
|
||||
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/vfs"
|
||||
|
@ -43,12 +44,12 @@ type transferQuotaUsage struct {
|
|||
|
||||
func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
render.JSON(w, r, dataprovider.QuotaScans.GetUsersQuotaScans())
|
||||
render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans())
|
||||
}
|
||||
|
||||
func getFoldersQuotaScans(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
render.JSON(w, r, dataprovider.QuotaScans.GetVFoldersQuotaScans())
|
||||
render.JSON(w, r, common.QuotaScans.GetVFoldersQuotaScans())
|
||||
}
|
||||
|
||||
func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -141,11 +142,11 @@ func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username str
|
|||
"", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
if !common.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
sendAPIResponse(w, r, err, "A quota scan is in progress for this user", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset)
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
|
@ -170,11 +171,11 @@ func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name strin
|
|||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
sendAPIResponse(w, r, err, "A quota scan is in progress for this folder", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
err = dataprovider.UpdateVirtualFolderQuota(&folder, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset)
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
|
@ -193,7 +194,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin
|
|||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
if !common.QuotaScans.AddUserQuotaScan(user.Username) {
|
||||
sendAPIResponse(w, r, nil, fmt.Sprintf("Another scan is already in progress for user %#v", username),
|
||||
http.StatusConflict)
|
||||
return
|
||||
|
@ -212,7 +213,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
|
|||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
|
||||
sendAPIResponse(w, r, err, fmt.Sprintf("Another scan is already in progress for folder %#v", name),
|
||||
http.StatusConflict)
|
||||
return
|
||||
|
@ -222,7 +223,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
|
|||
}
|
||||
|
||||
func doUserQuotaScan(user dataprovider.User) error {
|
||||
defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||
numFiles, size, err := user.ScanQuota()
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "", "error scanning user quota %#v: %v", user.Username, err)
|
||||
|
@ -234,7 +235,7 @@ func doUserQuotaScan(user dataprovider.User) error {
|
|||
}
|
||||
|
||||
func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error {
|
||||
defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||
f := vfs.VirtualFolder{
|
||||
BaseVirtualFolder: folder,
|
||||
VirtualPath: "/",
|
||||
|
|
|
@ -1951,6 +1951,8 @@ func TestHTTPUserAuthEmptyPassword(t *testing.T) {
|
|||
c.CloseIdleConnections()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
err = resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, "")
|
||||
if assert.Error(t, err) {
|
||||
|
@ -1986,6 +1988,8 @@ func TestHTTPAnonymousUser(t *testing.T) {
|
|||
c.CloseIdleConnections()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
err = resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
if assert.Error(t, err) {
|
||||
|
@ -9347,12 +9351,12 @@ func TestUpdateUserQuotaUsageMock(t *testing.T) {
|
|||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusBadRequest, rr)
|
||||
assert.True(t, dataprovider.QuotaScans.AddUserQuotaScan(user.Username))
|
||||
assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username))
|
||||
req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON))
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusConflict, rr)
|
||||
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username))
|
||||
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))
|
||||
req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
|
@ -9624,12 +9628,12 @@ func TestStartQuotaScanMock(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
// simulate a duplicate quota scan
|
||||
dataprovider.QuotaScans.AddUserQuotaScan(user.Username)
|
||||
common.QuotaScans.AddUserQuotaScan(user.Username)
|
||||
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusConflict, rr)
|
||||
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username))
|
||||
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))
|
||||
|
||||
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
|
||||
setBearerForReq(req, token)
|
||||
|
@ -9743,13 +9747,13 @@ func TestUpdateFolderQuotaUsageMock(t *testing.T) {
|
|||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusBadRequest, rr)
|
||||
|
||||
assert.True(t, dataprovider.QuotaScans.AddVFolderQuotaScan(folderName))
|
||||
assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName))
|
||||
req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"),
|
||||
bytes.NewBuffer(folderAsJSON))
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusConflict, rr)
|
||||
assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
|
||||
req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
|
||||
setBearerForReq(req, token)
|
||||
|
@ -9778,12 +9782,12 @@ func TestStartFolderQuotaScanMock(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
// simulate a duplicate quota scan
|
||||
dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
common.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusConflict, rr)
|
||||
assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))
|
||||
// and now a real quota scan
|
||||
_, err = os.Stat(mappedPath)
|
||||
if err != nil && errors.Is(err, fs.ErrNotExist) {
|
||||
|
@ -20392,7 +20396,7 @@ func startOIDCMockServer() {
|
|||
|
||||
func waitForUsersQuotaScan(t *testing.T, token string) {
|
||||
for {
|
||||
var scans []dataprovider.ActiveQuotaScan
|
||||
var scans []common.ActiveQuotaScan
|
||||
req, _ := http.NewRequest(http.MethodGet, quotaScanPath, nil)
|
||||
setBearerForReq(req, token)
|
||||
rr := executeRequest(req)
|
||||
|
@ -20410,7 +20414,7 @@ func waitForUsersQuotaScan(t *testing.T, token string) {
|
|||
}
|
||||
|
||||
func waitForFoldersQuotaScanPath(t *testing.T, token string) {
|
||||
var scans []dataprovider.ActiveVirtualFolderQuotaScan
|
||||
var scans []common.ActiveVirtualFolderQuotaScan
|
||||
for {
|
||||
req, _ := http.NewRequest(http.MethodGet, quotaScanVFolderPath, nil)
|
||||
setBearerForReq(req, token)
|
||||
|
|
|
@ -1421,7 +1421,7 @@ func TestQuotaScanInvalidFs(t *testing.T) {
|
|||
Provider: sdk.S3FilesystemProvider,
|
||||
},
|
||||
}
|
||||
dataprovider.QuotaScans.AddUserQuotaScan(user.Username)
|
||||
common.QuotaScans.AddUserQuotaScan(user.Username)
|
||||
err := doUserQuotaScan(user)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -833,8 +833,8 @@ func GetEventRules(limit, offset int64, expectedStatusCode int) ([]dataprovider.
|
|||
}
|
||||
|
||||
// GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode.
|
||||
func GetQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveQuotaScan, []byte, error) {
|
||||
var quotaScans []dataprovider.ActiveQuotaScan
|
||||
func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) {
|
||||
var quotaScans []common.ActiveQuotaScan
|
||||
var body []byte
|
||||
resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken())
|
||||
if err != nil {
|
||||
|
@ -1077,8 +1077,8 @@ func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVi
|
|||
}
|
||||
|
||||
// GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode.
|
||||
func GetFoldersQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveVirtualFolderQuotaScan, []byte, error) {
|
||||
var quotaScans []dataprovider.ActiveVirtualFolderQuotaScan
|
||||
func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) {
|
||||
var quotaScans []common.ActiveVirtualFolderQuotaScan
|
||||
var body []byte
|
||||
resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken())
|
||||
if err != nil {
|
||||
|
|
|
@ -155,7 +155,7 @@ func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir st
|
|||
}
|
||||
|
||||
func TestRemoveNonexistentQuotaScan(t *testing.T) {
|
||||
assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan("username"))
|
||||
assert.False(t, common.QuotaScans.RemoveUserQuotaScan("username"))
|
||||
}
|
||||
|
||||
func TestGetOSOpenFlags(t *testing.T) {
|
||||
|
|
|
@ -1213,9 +1213,6 @@ func TestRealPath(t *testing.T) {
|
|||
for _, user := range []dataprovider.User{localUser, sftpUser} {
|
||||
conn, client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
|
||||
p, err := client.RealPath("../..")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "/", p)
|
||||
|
@ -1247,6 +1244,9 @@ func TestRealPath(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
_, err = client.RealPath(path.Join(subdir, "temp"))
|
||||
assert.ErrorIs(t, err, os.ErrPermission)
|
||||
|
||||
conn.Close()
|
||||
client.Close()
|
||||
err = os.Remove(filepath.Join(localUser.GetHomeDir(), subdir, "temp"))
|
||||
assert.NoError(t, err)
|
||||
if user.Username == localUser.Username {
|
||||
|
@ -4590,14 +4590,14 @@ func TestQuotaScan(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMultipleQuotaScans(t *testing.T) {
|
||||
res := dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername)
|
||||
res := common.QuotaScans.AddUserQuotaScan(defaultUsername)
|
||||
assert.True(t, res)
|
||||
res = dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername)
|
||||
res = common.QuotaScans.AddUserQuotaScan(defaultUsername)
|
||||
assert.False(t, res, "add quota must fail if another scan is already active")
|
||||
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername))
|
||||
activeScans := dataprovider.QuotaScans.GetUsersQuotaScans()
|
||||
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername))
|
||||
activeScans := common.QuotaScans.GetUsersQuotaScans()
|
||||
assert.Equal(t, 0, len(activeScans))
|
||||
assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername))
|
||||
assert.False(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername))
|
||||
}
|
||||
|
||||
func TestQuotaLimits(t *testing.T) {
|
||||
|
@ -6949,15 +6949,15 @@ func TestVirtualFolderQuotaScan(t *testing.T) {
|
|||
|
||||
func TestVFolderMultipleQuotaScan(t *testing.T) {
|
||||
folderName := "folder_name"
|
||||
res := dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
res := common.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
assert.True(t, res)
|
||||
res = dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
res = common.QuotaScans.AddVFolderQuotaScan(folderName)
|
||||
assert.False(t, res)
|
||||
res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)
|
||||
res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
|
||||
assert.True(t, res)
|
||||
activeScans := dataprovider.QuotaScans.GetVFoldersQuotaScans()
|
||||
activeScans := common.QuotaScans.GetVFoldersQuotaScans()
|
||||
assert.Len(t, activeScans, 0)
|
||||
res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)
|
||||
res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
|
||||
assert.False(t, res)
|
||||
}
|
||||
|
||||
|
|
|
@ -607,3 +607,10 @@ func GetTLSVersion(val int) uint16 {
|
|||
func IsEmailValid(email string) bool {
|
||||
return emailRegex.MatchString(email)
|
||||
}
|
||||
|
||||
// PanicOnError calls panic if err is not nil
|
||||
func PanicOnError(err error) {
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("unexpected error: %w", err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -156,7 +156,7 @@ func (fs *CryptFs) Create(name string, flag int) (File, *PipeWriter, func(), err
|
|||
if flag == 0 {
|
||||
f, err = os.Create(name)
|
||||
} else {
|
||||
f, err = os.OpenFile(name, flag, os.ModePerm)
|
||||
f, err = os.OpenFile(name, flag, 0666)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
|
|
Loading…
Reference in a new issue