refactor: move eventmanager to common package

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-08-01 18:48:54 +02:00
parent 3ca62d76d7
commit 9d2b5dc07d
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
24 changed files with 2030 additions and 1161 deletions

6
go.mod
View file

@ -52,7 +52,7 @@ require (
github.com/rs/xid v1.4.0
github.com/rs/zerolog v1.27.0
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a
github.com/shirou/gopsutil/v3 v3.22.6
github.com/shirou/gopsutil/v3 v3.22.7
github.com/spf13/afero v1.9.2
github.com/spf13/cobra v1.5.0
github.com/spf13/viper v1.12.0
@ -68,7 +68,7 @@ require (
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/net v0.0.0-20220728211354-c7608f3a8462
golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
golang.org/x/sys v0.0.0-20220731174439-a90be440212d
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9
google.golang.org/api v0.90.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
@ -155,7 +155,7 @@ require (
golang.org/x/tools v0.1.12 // indirect
golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 // indirect
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f // indirect
google.golang.org/grpc v1.48.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/ini.v1 v1.66.6 // indirect

12
go.sum
View file

@ -712,8 +712,8 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh
github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo=
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a h1:X9qPZ+GPQ87TnBDNZN6dyX7FkjhwnFh98WgB6Y1T5O8=
github.com/sftpgo/sdk v0.1.2-0.20220727164210-06723ba7ce9a/go.mod h1:RL4HeorXC6XgqtkLYnQUSogLdsdMfbsogIvdBVLuy4w=
github.com/shirou/gopsutil/v3 v3.22.6 h1:FnHOFOh+cYAM0C30P+zysPISzlknLC5Z1G4EAElznfQ=
github.com/shirou/gopsutil/v3 v3.22.6/go.mod h1:EdIubSnZhbAvBS1yJ7Xi+AShB/hxwLHOMz4MCYz7yMs=
github.com/shirou/gopsutil/v3 v3.22.7 h1:flKnuCMfUUrO+oAvwAd6GKZgnPzr098VA/UJ14nhJd4=
github.com/shirou/gopsutil/v3 v3.22.7/go.mod h1:s648gW4IywYzUfE/KjXxUsqrqx/T2xO5VqOXxONeRfI=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
@ -746,7 +746,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 h1:b2nJXyPCa9HY7giGM+kYcnQ71m14JnGdQabMPmyt++8=
@ -973,8 +972,9 @@ golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -1225,8 +1225,8 @@ google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljW
google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220624142145-8cd45d7dbd1f/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220628213854-d9e0b6570c03/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9 h1:d3fKQZK+1rWQMg3xLKQbPMirUCo29I/NRdI2WarSzTg=
google.golang.org/genproto v0.0.0-20220728213248-dd149ef739b9/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc=
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f h1:XVHpVMvPs4MtH3h6cThzKs2snNexcfd35vQx2T3IuIY=
google.golang.org/genproto v0.0.0-20220801145646-83ce21fca29f/go.mod h1:iHe1svFLAZg9VWz891+QbRMwUv9O/1Ww+/mngYeThbc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=

View file

@ -102,7 +102,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
) error {
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
hasRules := dataprovider.EventManager.HasFsRules()
hasRules := eventManager.hasFsRules()
if !hasHook && !hasNotifiersPlugin && !hasRules {
return nil
}
@ -113,7 +113,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua
}
var errRes error
if hasRules {
errRes = dataprovider.EventManager.HandleFsEvent(dataprovider.EventParams{
errRes = eventManager.handleFsEvent(EventParams{
Name: notification.Username,
Event: notification.Action,
Status: notification.Status,

View file

@ -138,11 +138,11 @@ var (
// Config is the configuration for the supported protocols
Config Configuration
// Connections is the list of active connections
Connections ActiveConnections
transfersChecker TransfersChecker
periodicTimeoutTicker *time.Ticker
periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
Connections ActiveConnections
// QuotaScans is the list of active quota scans
QuotaScans ActiveScans
transfersChecker TransfersChecker
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters
@ -157,7 +157,7 @@ func Initialize(c Configuration, isShared int) error {
Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true)
Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
startPeriodicChecks(periodicTimeoutCheckInterval)
Config.defender = nil
Config.whitelist = nil
rateLimiters = make(map[string][]*rateLimiter)
@ -308,35 +308,18 @@ func AddDefenderEvent(ip string, event HostEvent) {
Config.defender.AddEvent(ip, event)
}
// the ticker cannot be started/stopped from multiple goroutines
func startPeriodicTimeoutTicker(duration time.Duration) {
stopPeriodicTimeoutTicker()
periodicTimeoutTicker = time.NewTicker(duration)
periodicTimeoutTickerDone = make(chan bool)
go func() {
counter := int64(0)
func startPeriodicChecks(duration time.Duration) {
startEventScheduler()
spec := fmt.Sprintf("@every %s", duration)
_, err := eventScheduler.AddFunc(spec, Connections.checkTransfers)
util.PanicOnError(err)
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec)
if Config.IdleTimeout > 0 {
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
for {
select {
case <-periodicTimeoutTickerDone:
return
case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles()
}
go Connections.checkTransfers()
}
}
}()
}
func stopPeriodicTimeoutTicker() {
if periodicTimeoutTicker != nil {
periodicTimeoutTicker.Stop()
periodicTimeoutTickerDone <- true
periodicTimeoutTicker = nil
spec = fmt.Sprintf("@every %s", duration*ratio)
_, err = eventScheduler.AddFunc(spec, Connections.checkIdles)
util.PanicOnError(err)
logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec)
}
}
@ -1162,3 +1145,117 @@ func (c *ConnectionStatus) GetTransfersAsString() string {
}
return result
}
// ActiveQuotaScan defines an active quota scan for a user home dir
type ActiveQuotaScan struct {
// Username to which the quota scan refers
Username string `json:"username"`
// quota scan start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
}
// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder
type ActiveVirtualFolderQuotaScan struct {
// folder name to which the quota scan refers
Name string `json:"name"`
// quota scan start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
}
// ActiveScans holds the active quota scans.
// The embedded RWMutex guards both slices, so a value of this type
// must not be copied after first use.
type ActiveScans struct {
sync.RWMutex
// active scans for users home directories
UserScans []ActiveQuotaScan
// active scans for virtual folders
FolderScans []ActiveVirtualFolderQuotaScan
}
// GetUsersQuotaScans returns the active quota scans for users home directories.
// A copy is returned so callers cannot observe later mutations of the
// internal slice.
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
	s.RLock()
	defer s.RUnlock()

	result := make([]ActiveQuotaScan, len(s.UserScans))
	for idx := range s.UserScans {
		result[idx] = s.UserScans[idx]
	}
	return result
}
// AddUserQuotaScan adds a user to the ones with active quota scans.
// Returns false if the user has a quota scan already running
func (s *ActiveScans) AddUserQuotaScan(username string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.UserScans {
		if s.UserScans[idx].Username == username {
			// a scan for this user is already in progress
			return false
		}
	}
	s.UserScans = append(s.UserScans, ActiveQuotaScan{
		Username:  username,
		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	return true
}
// RemoveUserQuotaScan removes a user from the ones with active quota scans.
// Returns false if the user has no active quota scans
func (s *ActiveScans) RemoveUserQuotaScan(username string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.UserScans {
		if s.UserScans[idx].Username != username {
			continue
		}
		// swap-remove: order of the remaining scans is not significant
		last := len(s.UserScans) - 1
		s.UserScans[idx] = s.UserScans[last]
		s.UserScans = s.UserScans[:last]
		return true
	}
	return false
}
// GetVFoldersQuotaScans returns the active quota scans for virtual folders.
// A copy is returned so callers cannot observe later mutations of the
// internal slice.
func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan {
	s.RLock()
	defer s.RUnlock()

	result := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans))
	for idx := range s.FolderScans {
		result[idx] = s.FolderScans[idx]
	}
	return result
}
// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans.
// Returns false if the folder has a quota scan already running
func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.FolderScans {
		if s.FolderScans[idx].Name == folderName {
			// a scan for this folder is already in progress
			return false
		}
	}
	s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{
		Name:      folderName,
		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	return true
}
// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans.
// Returns false if the folder has no active quota scans
func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.FolderScans {
		if s.FolderScans[idx].Name != folderName {
			continue
		}
		// swap-remove: order of the remaining scans is not significant
		last := len(s.FolderScans) - 1
		s.FolderScans[idx] = s.FolderScans[last]
		s.FolderScans = s.FolderScans[:last]
		return true
	}
	return false
}

View file

@ -21,7 +21,6 @@ import (
"net"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
@ -533,19 +532,19 @@ func TestIdleConnections(t *testing.T) {
assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock()
startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
defer Connections.RUnlock()
return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond)
stopPeriodicTimeoutTicker()
stopEventScheduler()
assert.Len(t, Connections.GetStats(), 2)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity = c.lastActivity
startPeriodicTimeoutTicker(100 * time.Millisecond)
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
@ -553,7 +552,7 @@ func TestIdleConnections(t *testing.T) {
return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections())
stopPeriodicTimeoutTicker()
stopEventScheduler()
assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed)
@ -719,6 +718,35 @@ func TestConnectionStatus(t *testing.T) {
assert.Len(t, stats, 0)
}
// TestQuotaScans verifies the ActiveScans add/get/remove semantics for both
// user and virtual folder quota scans.
func TestQuotaScans(t *testing.T) {
username := "username"
assert.True(t, QuotaScans.AddUserQuotaScan(username))
// a second scan for the same user must be refused while one is active
assert.False(t, QuotaScans.AddUserQuotaScan(username))
usersScans := QuotaScans.GetUsersQuotaScans()
if assert.Len(t, usersScans, 1) {
assert.Equal(t, usersScans[0].Username, username)
assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
// GetUsersQuotaScans must return a copy: mutating the internal slice
// must not change the slice returned earlier
QuotaScans.UserScans[0].StartTime = 0
assert.NotEqual(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
}
assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
// removing a scan that no longer exists must return false
assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
assert.Len(t, usersScans, 1)
folderName := "folder"
assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].Name, folderName)
}
assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
}
func TestProxyProtocolVersion(t *testing.T) {
c := Configuration{
ProxyProtocol: 0,
@ -1033,110 +1061,6 @@ func TestUserRecentActivity(t *testing.T) {
assert.True(t, res)
}
// TestEventRuleMatch verifies provider and filesystem event condition
// matching: name patterns (including inverse match), provider object types,
// fs path patterns, protocols and file size bounds.
func TestEventRuleMatch(t *testing.T) {
conditions := dataprovider.EventConditions{
ProviderEvents: []string{"add", "update"},
Options: dataprovider.ConditionOptions{
Names: []dataprovider.ConditionPattern{
{
Pattern: "user1",
InverseMatch: true,
},
},
},
}
// "user1" is excluded by the inverse name match
res := conditions.ProviderEventMatch(dataprovider.EventParams{
Name: "user1",
Event: "add",
})
assert.False(t, res)
res = conditions.ProviderEventMatch(dataprovider.EventParams{
Name: "user2",
Event: "update",
})
assert.True(t, res)
// "delete" is not among the configured provider events
res = conditions.ProviderEventMatch(dataprovider.EventParams{
Name: "user2",
Event: "delete",
})
assert.False(t, res)
conditions.Options.ProviderObjects = []string{"api_key"}
// object type "share" is not in the allowed provider objects
res = conditions.ProviderEventMatch(dataprovider.EventParams{
Name: "user2",
Event: "update",
ObjectType: "share",
})
assert.False(t, res)
res = conditions.ProviderEventMatch(dataprovider.EventParams{
Name: "user2",
Event: "update",
ObjectType: "api_key",
})
assert.True(t, res)
// now test fs events
conditions = dataprovider.EventConditions{
FsEvents: []string{operationUpload, operationDownload},
Options: dataprovider.ConditionOptions{
Names: []dataprovider.ConditionPattern{
{
Pattern: "user*",
},
{
Pattern: "tester*",
},
},
FsPaths: []dataprovider.ConditionPattern{
{
Pattern: "*.txt",
},
},
Protocols: []string{ProtocolSFTP},
MinFileSize: 10,
MaxFileSize: 30,
},
}
params := dataprovider.EventParams{
Name: "tester4",
Event: operationDelete,
VirtualPath: "/path.txt",
Protocol: ProtocolSFTP,
ObjectName: "path.txt",
FileSize: 20,
}
// delete is not a configured fs event
res = conditions.FsEventMatch(params)
assert.False(t, res)
params.Event = operationDownload
res = conditions.FsEventMatch(params)
assert.True(t, res)
// name matches neither "user*" nor "tester*"
params.Name = "name"
res = conditions.FsEventMatch(params)
assert.False(t, res)
params.Name = "user5"
res = conditions.FsEventMatch(params)
assert.True(t, res)
// neither virtual path nor object name match "*.txt"
params.VirtualPath = "/sub/f.jpg"
params.ObjectName = path.Base(params.VirtualPath)
res = conditions.FsEventMatch(params)
assert.False(t, res)
params.VirtualPath = "/sub/f.txt"
params.ObjectName = path.Base(params.VirtualPath)
res = conditions.FsEventMatch(params)
assert.True(t, res)
// protocol not in the allowed list
params.Protocol = ProtocolHTTP
res = conditions.FsEventMatch(params)
assert.False(t, res)
params.Protocol = ProtocolSFTP
// below MinFileSize
params.FileSize = 5
res = conditions.FsEventMatch(params)
assert.False(t, res)
// above MaxFileSize
params.FileSize = 50
res = conditions.FsEventMatch(params)
assert.False(t, res)
params.FileSize = 25
res = conditions.FsEventMatch(params)
assert.True(t, res)
}
func BenchmarkBcryptHashing(b *testing.B) {
bcryptPassword := "bcryptpassword"
for i := 0; i < b.N; i++ {

View file

@ -0,0 +1,776 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package common
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/robfig/cron/v3"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)
var (
// eventManager handles the supported event rules actions.
// It is initialized in init and kept in sync through dataprovider callbacks.
eventManager eventRulesContainer
)
// init wires the event manager into the dataprovider: rules are (re)loaded
// via loadRules, deleted via RemoveRule, and provider object changes are
// forwarded as provider events.
func init() {
eventManager = eventRulesContainer{
schedulesMapping: make(map[string][]cron.EntryID),
}
dataprovider.SetEventRulesCallbacks(eventManager.loadRules, eventManager.RemoveRule,
func(operation, executor, ip, objectType, objectName string, object plugin.Renderer) {
// map a provider object change to the event parameters; the
// executor becomes the event Name
eventManager.handleProviderEvent(EventParams{
Name: executor,
ObjectName: objectName,
Event: operation,
Status: 1,
ObjectType: objectType,
IP: ip,
Timestamp: time.Now().UnixNano(),
Object: object,
})
})
}
// eventRulesContainer stores event rules by trigger.
// The embedded RWMutex guards the rule slices and the schedules mapping.
type eventRulesContainer struct {
sync.RWMutex
// rules triggered by filesystem events
FsEvents []dataprovider.EventRule
// rules triggered by provider object changes
ProviderEvents []dataprovider.EventRule
// rules triggered on a cron schedule
Schedules []dataprovider.EventRule
// rule name -> scheduler entry IDs, used to unregister cron entries
// when a scheduled rule is removed
schedulesMapping map[string][]cron.EntryID
// timestamp (ms since epoch) of the last rules load, accessed atomically
lastLoad int64
}
// getLastLoadTime atomically reads the timestamp (ms since epoch) of the
// last rules load.
func (r *eventRulesContainer) getLastLoadTime() int64 {
return atomic.LoadInt64(&r.lastLoad)
}
// setLastLoadTime atomically stores the timestamp (ms since epoch) of the
// last rules load.
func (r *eventRulesContainer) setLastLoadTime(modTime int64) {
atomic.StoreInt64(&r.lastLoad, modTime)
}
// RemoveRule deletes the rule with the specified name
// from whichever trigger list currently holds it.
func (r *eventRulesContainer) RemoveRule(name string) {
r.Lock()
defer r.Unlock()
r.removeRuleInternal(name)
eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d",
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
}
// removeRuleInternal deletes the rule with the given name from whichever
// trigger list contains it; scheduled rules also have their cron entries
// unregistered from the scheduler. The caller must hold the write lock.
func (r *eventRulesContainer) removeRuleInternal(name string) {
for idx := range r.FsEvents {
if r.FsEvents[idx].Name == name {
// swap-remove: move the last element into the freed slot
lastIdx := len(r.FsEvents) - 1
r.FsEvents[idx] = r.FsEvents[lastIdx]
r.FsEvents = r.FsEvents[:lastIdx]
eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name)
return
}
}
for idx := range r.ProviderEvents {
if r.ProviderEvents[idx].Name == name {
lastIdx := len(r.ProviderEvents) - 1
r.ProviderEvents[idx] = r.ProviderEvents[lastIdx]
r.ProviderEvents = r.ProviderEvents[:lastIdx]
eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name)
return
}
}
for idx := range r.Schedules {
if r.Schedules[idx].Name == name {
// unregister every cron entry associated with this rule
if schedules, ok := r.schedulesMapping[name]; ok {
for _, entryID := range schedules {
eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name)
eventScheduler.Remove(entryID)
}
delete(r.schedulesMapping, name)
}
lastIdx := len(r.Schedules) - 1
r.Schedules[idx] = r.Schedules[lastIdx]
r.Schedules = r.Schedules[:lastIdx]
eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name)
return
}
}
}
// addUpdateRuleInternal replaces any existing rule with the same name and,
// unless the rule is soft-deleted, registers it under the list matching its
// trigger. The caller must hold the write lock.
func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule) {
r.removeRuleInternal(rule.Name)
if rule.DeletedAt > 0 {
// soft-deleted rule: do not register it; purge it from the provider
// once the deletion is older than 30 minutes
deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt)
if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt)
go dataprovider.RemoveEventRule(rule) //nolint:errcheck
}
return
}
switch rule.Trigger {
case dataprovider.EventTriggerFsEvent:
r.FsEvents = append(r.FsEvents, rule)
eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name)
case dataprovider.EventTriggerProviderEvent:
r.ProviderEvents = append(r.ProviderEvents, rule)
eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name)
case dataprovider.EventTriggerSchedule:
// register one cron job per configured schedule and remember the
// entry IDs so they can be removed with the rule
for _, schedule := range rule.Conditions.Schedules {
cronSpec := schedule.GetCronSpec()
job := &eventCronJob{
ruleName: dataprovider.ConvertName(rule.Name),
}
entryID, err := eventScheduler.AddJob(cronSpec, job)
if err != nil {
// NOTE(review): on error the rule is not appended to r.Schedules,
// but entries already registered for it stay in schedulesMapping
eventManagerLog(logger.LevelError, "unable to add scheduled rule %q, cron string %q: %v", rule.Name, cronSpec, err)
return
}
r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID)
eventManagerLog(logger.LevelDebug, "schedule for rule %q added, id: %d, cron string %q, active scheduling rules: %d",
rule.Name, entryID, cronSpec, len(r.schedulesMapping))
}
r.Schedules = append(r.Schedules, rule)
eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name)
default:
eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger)
}
}
// loadRules fetches the rules updated after the last load time and merges
// them into the container. It is invoked through the dataprovider callback
// registered in init.
func (r *eventRulesContainer) loadRules() {
eventManagerLog(logger.LevelDebug, "loading updated rules")
// capture the timestamp before querying so updates happening while we
// load are picked up by the next call
modTime := util.GetTimeAsMsSinceEpoch(time.Now())
rules, err := dataprovider.GetRecentlyUpdatedRules(r.getLastLoadTime())
if err != nil {
eventManagerLog(logger.LevelError, "unable to load event rules: %v", err)
return
}
eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules))
if len(rules) > 0 {
r.Lock()
defer r.Unlock()
for _, rule := range rules {
r.addUpdateRuleInternal(rule)
}
}
eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d",
len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
r.setLastLoadTime(modTime)
}
// checkProviderEventMatch reports whether the given provider event
// parameters satisfy the rule conditions: the event name must be listed,
// the object name must match the name patterns and, when provider object
// filters are set, the object type must be among them.
func (r *eventRulesContainer) checkProviderEventMatch(conditions dataprovider.EventConditions, params EventParams) bool {
	if !util.Contains(conditions.ProviderEvents, params.Event) {
		return false
	}
	if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
		return false
	}
	if len(conditions.Options.ProviderObjects) == 0 {
		// no object type filter configured
		return true
	}
	return util.Contains(conditions.Options.ProviderObjects, params.ObjectType)
}
// checkFsEventMatch reports whether the given filesystem event parameters
// satisfy the rule conditions: event name, name patterns, fs path patterns
// (matched against the virtual path or the object name), protocol and,
// for uploads/downloads, the configured file size bounds.
func (r *eventRulesContainer) checkFsEventMatch(conditions dataprovider.EventConditions, params EventParams) bool {
	if !util.Contains(conditions.FsEvents, params.Event) {
		return false
	}
	if !checkEventConditionPatterns(params.Name, conditions.Options.Names) {
		return false
	}
	pathMatch := checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) ||
		checkEventConditionPatterns(params.ObjectName, conditions.Options.FsPaths)
	if !pathMatch {
		return false
	}
	if len(conditions.Options.Protocols) > 0 && !util.Contains(conditions.Options.Protocols, params.Protocol) {
		return false
	}
	// size limits only apply to transfer events
	if params.Event == operationUpload || params.Event == operationDownload {
		if lo := conditions.Options.MinFileSize; lo > 0 && params.FileSize < lo {
			return false
		}
		if hi := conditions.Options.MaxFileSize; hi > 0 && params.FileSize > hi {
			return false
		}
	}
	return true
}
// hasFsRules returns true if there are any rules for filesystem event triggers
func (r *eventRulesContainer) hasFsRules() bool {
r.RLock()
defer r.RUnlock()
return len(r.FsEvents) > 0
}
// handleFsEvent executes the rules actions defined for the specified event.
// Rules without any synchronous action run in a background goroutine; rules
// with at least one synchronous action run inline and their error is
// returned to the caller.
func (r *eventRulesContainer) handleFsEvent(params EventParams) error {
	r.RLock()
	var syncRules, asyncRules []dataprovider.EventRule
	for _, rule := range r.FsEvents {
		if !r.checkFsEventMatch(rule.Conditions, params) {
			continue
		}
		isSync := false
		for _, action := range rule.Actions {
			if action.Options.ExecuteSync {
				isSync = true
				break
			}
		}
		if isSync {
			syncRules = append(syncRules, rule)
		} else {
			asyncRules = append(asyncRules, rule)
		}
	}
	r.RUnlock()

	if len(asyncRules) > 0 {
		go executeAsyncRulesActions(asyncRules, params)
	}
	if len(syncRules) > 0 {
		return executeSyncRulesActions(syncRules, params)
	}
	return nil
}
// handleProviderEvent executes, asynchronously, the actions of every
// provider-event rule matching the given parameters.
func (r *eventRulesContainer) handleProviderEvent(params EventParams) {
	r.RLock()
	defer r.RUnlock()

	var matched []dataprovider.EventRule
	for _, rule := range r.ProviderEvents {
		if !r.checkProviderEventMatch(rule.Conditions, params) {
			continue
		}
		matched = append(matched, rule)
	}
	if len(matched) > 0 {
		go executeAsyncRulesActions(matched, params)
	}
}
// EventParams defines the supported event parameters
type EventParams struct {
// Name is the username for fs events, the executor for provider events
Name string
// Event is the fs or provider operation name
Event string
Status int
VirtualPath string
FsPath string
VirtualTargetPath string
FsTargetPath string
ObjectName string
// ObjectType is set for provider events
ObjectType string
FileSize int64
Protocol string
IP string
// Timestamp is in nanoseconds since epoch
Timestamp int64
// Object is the provider object involved, rendered as JSON on demand
Object plugin.Renderer
}
// getStringReplacements returns the placeholder/value pairs to feed a
// strings.Replacer for expanding {{...}} markers in rule action templates.
// The object JSON is rendered only when addObjectData is true.
func (p *EventParams) getStringReplacements(addObjectData bool) []string {
	replacements := []string{
		"{{Name}}", p.Name,
		"{{Event}}", p.Event,
		"{{Status}}", strconv.Itoa(p.Status),
		"{{VirtualPath}}", p.VirtualPath,
		"{{FsPath}}", p.FsPath,
		"{{VirtualTargetPath}}", p.VirtualTargetPath,
		"{{FsTargetPath}}", p.FsTargetPath,
		"{{ObjectName}}", p.ObjectName,
		"{{ObjectType}}", p.ObjectType,
		"{{FileSize}}", strconv.FormatInt(p.FileSize, 10),
		"{{Protocol}}", p.Protocol,
		"{{IP}}", p.IP,
		"{{Timestamp}}", strconv.FormatInt(p.Timestamp, 10),
	}
	if addObjectData {
		// hide sensitive fields unless the object was deleted
		data, err := p.Object.RenderAsJSON(p.Event != operationDelete)
		if err == nil {
			// best-effort: on render error the placeholder is simply left out
			replacements = append(replacements, "{{ObjectData}}", string(data))
		}
	}
	return replacements
}
// replaceWithReplacer expands the {{...}} placeholders in input using the
// given replacer. Inputs without any "{{" marker are returned unchanged,
// skipping the replacer entirely.
func replaceWithReplacer(input string, replacer *strings.Replacer) string {
	if strings.Contains(input, "{{") {
		return replacer.Replace(input)
	}
	return input
}
// checkEventConditionPattern reports whether name satisfies the given
// shell-style pattern, honoring the inverse match flag. A malformed
// pattern is logged and never matches.
func checkEventConditionPattern(p dataprovider.ConditionPattern, name string) bool {
	matched, err := path.Match(p.Pattern, name)
	if err != nil {
		eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err)
		return false
	}
	// XOR with InverseMatch: the result is inverted when the flag is set
	return matched != p.InverseMatch
}
// checkEventConditionPatterns reports whether name matches at least one of
// the given patterns. An empty pattern list matches everything.
func checkEventConditionPatterns(name string, patterns []dataprovider.ConditionPattern) bool {
	if len(patterns) == 0 {
		return true
	}
	for idx := range patterns {
		if checkEventConditionPattern(patterns[idx], name) {
			return true
		}
	}
	return false
}
// getHTTPRuleActionEndpoint returns the endpoint for an HTTP rule action.
// Configured query parameters, with {{...}} placeholders expanded, are
// appended to the endpoint URL; without query parameters the endpoint is
// returned as-is.
func getHTTPRuleActionEndpoint(c dataprovider.EventActionHTTPConfig, replacer *strings.Replacer) (string, error) {
	if len(c.QueryParameters) == 0 {
		return c.Endpoint, nil
	}
	u, err := url.Parse(c.Endpoint)
	if err != nil {
		return "", fmt.Errorf("invalid endpoint: %w", err)
	}
	query := u.Query()
	for _, kv := range c.QueryParameters {
		query.Add(kv.Key, replaceWithReplacer(kv.Value, replacer))
	}
	u.RawQuery = query.Encode()
	return u.String(), nil
}
// executeHTTPRuleAction sends the configured HTTP notification for an event
// rule action. It returns an error if the request cannot be built or sent,
// or if the response status code is outside the 200-204 range.
func executeHTTPRuleAction(c dataprovider.EventActionHTTPConfig, params EventParams) error {
	if !c.Password.IsEmpty() {
		if err := c.Password.TryDecrypt(); err != nil {
			return fmt.Errorf("unable to decrypt password: %w", err)
		}
	}
	// render the object as JSON only if the body actually references it
	addObjectData := params.Object != nil && strings.Contains(c.Body, "{{ObjectData}}")
	replacements := params.getStringReplacements(addObjectData)
	replacer := strings.NewReplacer(replacements...)
	endpoint, err := getHTTPRuleActionEndpoint(c, replacer)
	if err != nil {
		return err
	}
	var body io.Reader
	if c.Body != "" && c.Method != http.MethodGet {
		body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer))
	}
	req, err := http.NewRequest(c.Method, endpoint, body)
	if err != nil {
		return err
	}
	if c.Username != "" {
		// NOTE(review): the password is read via GetAdditionalData() —
		// confirm this is intended and not the decrypted payload
		req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData())
	}
	for _, keyVal := range c.Headers {
		req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
	}
	client := c.GetHTTPClient()
	defer client.CloseIdleConnections()

	startTime := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v",
			endpoint, time.Since(startTime), err)
		return err
	}
	defer resp.Body.Close()

	// fixed typo: "endopoint" -> "endpoint"
	eventManagerLog(logger.LevelDebug, "http notification sent, endpoint: %s, elapsed: %s, status code: %d",
		endpoint, time.Since(startTime), resp.StatusCode)
	if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	return nil
}
// executeCommandRuleAction runs the configured command for an event rule
// action. Event parameters are exposed to the command through the
// configured environment variables, with {{...}} placeholders expanded.
func executeCommandRuleAction(c dataprovider.EventActionCommandConfig, params EventParams) error {
envVars := make([]string, 0, len(c.EnvVars))
addObjectData := false
if params.Object != nil {
// render the object as JSON only if some env var actually references it
for _, k := range c.EnvVars {
if strings.Contains(k.Value, "{{ObjectData}}") {
addObjectData = true
break
}
}
}
replacements := params.getStringReplacements(addObjectData)
replacer := strings.NewReplacer(replacements...)
for _, keyVal := range c.EnvVars {
envVars = append(envVars, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)))
}
// the command is killed when the timeout expires
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, c.Cmd)
// inherit the parent environment, then append the action specific variables
cmd.Env = append(cmd.Env, os.Environ()...)
cmd.Env = append(cmd.Env, envVars...)
startTime := time.Now()
err := cmd.Run()
eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v",
c.Cmd, time.Since(startTime), err)
return err
}
// executeEmailRuleAction sends the configured email notification for an
// event rule action, expanding {{...}} placeholders in subject and body.
func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params EventParams) error {
	// render the object as JSON only if the body actually references it
	addObjectData := params.Object != nil && strings.Contains(c.Body, "{{ObjectData}}")
	replacer := strings.NewReplacer(params.getStringReplacements(addObjectData)...)
	subject := replaceWithReplacer(c.Subject, replacer)
	body := replaceWithReplacer(c.Body, replacer)

	startTime := time.Now()
	err := smtp.SendEmail(c.Recipients, subject, body, smtp.EmailContentTypeTextPlain)
	eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v",
		time.Since(startTime), err)
	return err
}
// executeUsersQuotaResetRuleAction rescans and updates the disk quota usage
// for every user whose username matches the configured name conditions.
// It returns a single error listing the users for which the reset failed.
func executeUsersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
	users, err := dataprovider.DumpUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	for idx := range users {
		user := users[idx]
		if !checkEventConditionPatterns(user.Username, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match",
				user.Username)
			continue
		}
		// refuse to start if another scan for this user is already running
		if !QuotaScans.AddUserQuotaScan(user.Username) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username)
			failures = append(failures, user.Username)
			continue
		}
		numFiles, size, errScan := user.ScanQuota()
		QuotaScans.RemoveUserQuotaScan(user.Username)
		if errScan != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, errScan)
			failures = append(failures, user.Username)
			continue
		}
		if err := dataprovider.UpdateUserQuota(&user, numFiles, size, true); err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err)
			failures = append(failures, user.Username)
			continue
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for users: %+v", failures)
	}
	return nil
}
// executeFoldersQuotaResetRuleAction rescans and updates the disk quota usage
// for every virtual folder whose name matches the configured name conditions.
// It returns a single error listing the folders for which the reset failed.
func executeFoldersQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
	folders, err := dataprovider.DumpFolders()
	if err != nil {
		return fmt.Errorf("unable to get folders: %w", err)
	}
	var failures []string
	for idx := range folders {
		folder := folders[idx]
		if !checkEventConditionPatterns(folder.Name, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match",
				folder.Name)
			continue
		}
		// refuse to start if another scan for this folder is already running
		if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name)
			failures = append(failures, folder.Name)
			continue
		}
		// scan the whole folder starting from its root
		f := vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/",
		}
		numFiles, size, errScan := f.ScanQuota()
		QuotaScans.RemoveVFolderQuotaScan(folder.Name)
		if errScan != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, errScan)
			failures = append(failures, folder.Name)
			continue
		}
		if err := dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true); err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err)
			failures = append(failures, folder.Name)
			continue
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for folders: %+v", failures)
	}
	return nil
}
// executeTransferQuotaResetRuleAction resets the used upload/download data
// transfer to zero for every user whose username matches the configured name
// conditions. It returns a single error listing the users for which the reset
// failed.
func executeTransferQuotaResetRuleAction(conditions dataprovider.ConditionOptions) error {
	users, err := dataprovider.DumpUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	for idx := range users {
		user := users[idx]
		if !checkEventConditionPatterns(user.Username, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match",
				user.Username)
			continue
		}
		if err := dataprovider.UpdateUserTransferQuota(&user, 0, 0, true); err != nil {
			eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err)
			failures = append(failures, user.Username)
			continue
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("transfer quota reset failed for users: %+v", failures)
	}
	return nil
}
// executeRuleAction dispatches a single event action to the appropriate
// executor based on its type. The conditions are only used by the quota
// reset actions to filter the affected users/folders by name.
// It returns an error for unsupported action types.
func executeRuleAction(action dataprovider.BaseEventAction, params EventParams, conditions dataprovider.ConditionOptions) error {
	switch action.Type {
	case dataprovider.ActionTypeHTTP:
		return executeHTTPRuleAction(action.Options.HTTPConfig, params)
	case dataprovider.ActionTypeCommand:
		return executeCommandRuleAction(action.Options.CmdConfig, params)
	case dataprovider.ActionTypeEmail:
		return executeEmailRuleAction(action.Options.EmailConfig, params)
	case dataprovider.ActionTypeBackup:
		return dataprovider.ExecuteBackup()
	case dataprovider.ActionTypeUserQuotaReset:
		return executeUsersQuotaResetRuleAction(conditions)
	case dataprovider.ActionTypeFolderQuotaReset:
		return executeFoldersQuotaResetRuleAction(conditions)
	case dataprovider.ActionTypeTransferQuotaReset:
		return executeTransferQuotaResetRuleAction(conditions)
	default:
		return fmt.Errorf("unsupported action type: %d", action.Type)
	}
}
// executeSyncRulesActions executes, in order, the synchronous (non-failure)
// actions of the given rules and returns the last error encountered, if any.
// After the sync actions of each rule, its asynchronous actions — including
// any failure actions — are run in a separate goroutine, passing along the
// names of the sync actions that failed.
func executeSyncRulesActions(rules []dataprovider.EventRule, params EventParams) error {
	var errRes error
	for _, rule := range rules {
		var failedActions []string
		for _, action := range rule.Actions {
			if !action.Options.IsFailureAction && action.Options.ExecuteSync {
				startTime := time.Now()
				if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					failedActions = append(failedActions, action.Name)
					// we return the last error, it is ok for now
					errRes = err
					// StopOnFailure skips the remaining sync actions of this rule
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
		// execute async actions if any, including failure actions
		go executeRuleAsyncActions(rule, params, failedActions)
	}
	return errRes
}
// executeAsyncRulesActions runs the asynchronous actions of each of the
// given rules with no previously failed actions.
func executeAsyncRulesActions(rules []dataprovider.EventRule, params EventParams) {
	for idx := range rules {
		executeRuleAsyncActions(rules[idx], params, nil)
	}
}
// executeRuleAsyncActions executes the asynchronous (non-sync, non-failure)
// actions of the given rule in order. failedActions carries the names of any
// sync actions that already failed; if, after the async actions run, at least
// one action has failed, the rule's failure actions are executed as well.
func executeRuleAsyncActions(rule dataprovider.EventRule, params EventParams, failedActions []string) {
	for _, action := range rule.Actions {
		if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
			startTime := time.Now()
			if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
				eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
					action.Name, rule.Name, time.Since(startTime), err)
				failedActions = append(failedActions, action.Name)
				// StopOnFailure skips the remaining async actions of this rule
				if action.Options.StopOnFailure {
					break
				}
			} else {
				eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
					action.Name, rule.Name, time.Since(startTime))
			}
		}
	}
	if len(failedActions) > 0 {
		// execute failure actions
		for _, action := range rule.Actions {
			if action.Options.IsFailureAction {
				startTime := time.Now()
				if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
	}
}
// eventCronJob is the job executed by the event scheduler for rules with a
// schedule trigger. Its Run method presumably satisfies the robfig/cron Job
// interface — confirm against the scheduler registration code.
type eventCronJob struct {
	// ruleName is the name of the event rule to execute on each run.
	ruleName string
}
// getTask returns the provider task used to guard the rule against concurrent
// execution across multiple instances. If the rule does not require the
// guard, it returns an empty task and no error. If no task exists yet, one is
// created. A non-nil error means the caller must skip execution.
func (j *eventCronJob) getTask(rule dataprovider.EventRule) (dataprovider.Task, error) {
	if rule.GuardFromConcurrentExecution() {
		task, err := dataprovider.GetTaskByName(rule.Name)
		if _, ok := err.(*util.RecordNotFoundError); ok {
			// first execution for this rule: create the task record
			eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
			task = dataprovider.Task{
				Name:     rule.Name,
				UpdateAt: 0,
				Version:  0,
			}
			err = dataprovider.AddTask(rule.Name)
			if err != nil {
				eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
				return task, err
			}
		} else {
			eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err)
		}
		// err is nil here when the task was found or successfully created
		return task, err
	}
	return dataprovider.Task{}, nil
}
// Run executes the scheduled rule. If the rule is guarded against concurrent
// execution (task.Name != ""), it first claims the shared task record by
// bumping its version, then keeps the task timestamp fresh from a background
// goroutine while the actions run, so other instances see the task as active
// and skip their own execution.
func (j *eventCronJob) Run() {
	eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName)
	rule, err := dataprovider.EventRuleExists(j.ruleName)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName)
		return
	}
	task, err := j.getTask(rule)
	if err != nil {
		return
	}
	if task.Name != "" {
		updateInterval := 5 * time.Minute
		updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt)
		// skip if the task was touched within roughly twice the keep-alive
		// interval: another instance is likely still running this rule.
		// NOTE(review): the "+ 1" adds one nanosecond to the duration —
		// presumably intended as a small epsilon; confirm it is deliberate.
		if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) {
			eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt)
			return
		}
		// claim the task: fails (optimistic locking) if another instance
		// updated it first
		err = dataprovider.UpdateTask(rule.Name, task.Version)
		if err != nil {
			eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v",
				rule.Name, err)
			return
		}
		ticker := time.NewTicker(updateInterval)
		done := make(chan bool)

		// keep-alive worker: refresh the task timestamp until the actions finish
		go func(taskName string) {
			eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName)
			for {
				select {
				case <-done:
					eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName)
					return
				case <-ticker.C:
					err := dataprovider.UpdateTaskTimestamp(taskName)
					eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err)
				}
			}
		}(task.Name)

		executeRuleAsyncActions(rule, EventParams{}, nil)

		done <- true
		ticker.Stop()
	} else {
		executeRuleAsyncActions(rule, EventParams{}, nil)
	}
	eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName)
}
// eventManagerLog adds a log entry at the given level using the
// "eventmanager" sender and an empty connection ID.
func eventManagerLog(level logger.LogLevel, format string, v ...any) {
	logger.Log(level, "eventmanager", "", format, v...)
}

View file

@ -0,0 +1,674 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package common
import (
"fmt"
"net/http"
"os"
"path"
"path/filepath"
"testing"
"time"
"github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
// TestEventRuleMatch verifies that provider and filesystem event conditions
// are matched correctly against event parameters: name patterns (including
// inverse match), provider object filters, fs event names, path patterns,
// protocols and file size bounds.
func TestEventRuleMatch(t *testing.T) {
	conditions := dataprovider.EventConditions{
		ProviderEvents: []string{"add", "update"},
		Options: dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern:      "user1",
					InverseMatch: true,
				},
			},
		},
	}
	// "user1" is excluded by the inverse match
	res := eventManager.checkProviderEventMatch(conditions, EventParams{
		Name:  "user1",
		Event: "add",
	})
	assert.False(t, res)
	res = eventManager.checkProviderEventMatch(conditions, EventParams{
		Name:  "user2",
		Event: "update",
	})
	assert.True(t, res)
	// "delete" is not among the configured provider events
	res = eventManager.checkProviderEventMatch(conditions, EventParams{
		Name:  "user2",
		Event: "delete",
	})
	assert.False(t, res)
	conditions.Options.ProviderObjects = []string{"api_key"}
	res = eventManager.checkProviderEventMatch(conditions, EventParams{
		Name:       "user2",
		Event:      "update",
		ObjectType: "share",
	})
	assert.False(t, res)
	res = eventManager.checkProviderEventMatch(conditions, EventParams{
		Name:       "user2",
		Event:      "update",
		ObjectType: "api_key",
	})
	assert.True(t, res)
	// now test fs events
	conditions = dataprovider.EventConditions{
		FsEvents: []string{operationUpload, operationDownload},
		Options: dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: "user*",
				},
				{
					Pattern: "tester*",
				},
			},
			FsPaths: []dataprovider.ConditionPattern{
				{
					Pattern: "*.txt",
				},
			},
			Protocols:   []string{ProtocolSFTP},
			MinFileSize: 10,
			MaxFileSize: 30,
		},
	}
	params := EventParams{
		Name:        "tester4",
		Event:       operationDelete,
		VirtualPath: "/path.txt",
		Protocol:    ProtocolSFTP,
		ObjectName:  "path.txt",
		FileSize:    20,
	}
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.Event = operationDownload
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.True(t, res)
	params.Name = "name"
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.Name = "user5"
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.True(t, res)
	params.VirtualPath = "/sub/f.jpg"
	params.ObjectName = path.Base(params.VirtualPath)
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.VirtualPath = "/sub/f.txt"
	params.ObjectName = path.Base(params.VirtualPath)
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.True(t, res)
	params.Protocol = ProtocolHTTP
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.Protocol = ProtocolSFTP
	// file size outside the [MinFileSize, MaxFileSize] range must not match
	params.FileSize = 5
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.FileSize = 50
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
	params.FileSize = 25
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.True(t, res)
	// bad pattern
	conditions.Options.Names = []dataprovider.ConditionPattern{
		{
			Pattern: "[-]",
		},
	}
	res = eventManager.checkFsEventMatch(conditions, params)
	assert.False(t, res)
}
// TestEventManager verifies that the in-memory event manager state
// (FsEvents, ProviderEvents, Schedules and their mapping) is kept in sync
// with rules added, updated and deleted through the data provider, and that
// soft-deleted rules are eventually purged.
func TestEventManager(t *testing.T) {
	startEventScheduler()
	action := &dataprovider.BaseEventAction{
		Name: "test_action",
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint: "http://localhost",
				Timeout:  20,
				Method:   http.MethodGet,
			},
		},
	}
	err := dataprovider.AddEventAction(action, "", "")
	assert.NoError(t, err)
	rule := &dataprovider.EventRule{
		Name:    "rule",
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{operationUpload},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	err = dataprovider.AddEventRule(rule, "", "")
	assert.NoError(t, err)
	// an fs-event rule must be registered under FsEvents only
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 1)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	rule.Trigger = dataprovider.EventTriggerProviderEvent
	rule.Conditions = dataprovider.EventConditions{
		ProviderEvents: []string{"add"},
	}
	err = dataprovider.UpdateEventRule(rule, "", "")
	assert.NoError(t, err)
	// changing the trigger must move the rule between the internal maps
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 1)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	rule.Trigger = dataprovider.EventTriggerSchedule
	rule.Conditions = dataprovider.EventConditions{
		Schedules: []dataprovider.Schedule{
			{
				Hours:      "0",
				DayOfWeek:  "*",
				DayOfMonth: "*",
				Month:      "*",
			},
		},
	}
	// a rule flagged as deleted must not be registered and must be purged
	rule.DeletedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour))
	eventManager.addUpdateRuleInternal(*rule)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	assert.Eventually(t, func() bool {
		_, err = dataprovider.EventRuleExists(rule.Name)
		_, ok := err.(*util.RecordNotFoundError)
		return ok
	}, 2*time.Second, 100*time.Millisecond)
	rule.DeletedAt = 0
	err = dataprovider.AddEventRule(rule, "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 1)
	assert.Len(t, eventManager.schedulesMapping, 1)
	eventManager.RUnlock()
	err = dataprovider.DeleteEventRule(rule.Name, "", "")
	assert.NoError(t, err)
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	assert.Len(t, eventManager.schedulesMapping, 0)
	eventManager.RUnlock()
	err = dataprovider.DeleteEventAction(action.Name, "", "")
	assert.NoError(t, err)
	stopEventScheduler()
}
// TestEventManagerErrors verifies error handling when the data provider is
// closed (quota reset actions and rule loading must fail gracefully) and
// that rules with an invalid trigger or cron spec are not registered.
func TestEventManagerErrors(t *testing.T) {
	startEventScheduler()
	providerConf := dataprovider.GetProviderConfig()
	// close the provider so every provider-backed operation fails
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = executeUsersQuotaResetRuleAction(dataprovider.ConditionOptions{})
	assert.Error(t, err)
	err = executeFoldersQuotaResetRuleAction(dataprovider.ConditionOptions{})
	assert.Error(t, err)
	err = executeTransferQuotaResetRuleAction(dataprovider.ConditionOptions{})
	assert.Error(t, err)
	eventManager.loadRules()
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	eventManager.RUnlock()
	// rule with invalid trigger
	eventManager.addUpdateRuleInternal(dataprovider.EventRule{
		Name:    "test rule",
		Trigger: -1,
	})
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	eventManager.RUnlock()
	// rule with invalid cronspec
	eventManager.addUpdateRuleInternal(dataprovider.EventRule{
		Name:    "test rule",
		Trigger: dataprovider.EventTriggerSchedule,
		Conditions: dataprovider.EventConditions{
			Schedules: []dataprovider.Schedule{
				{
					Hours: "1000",
				},
			},
		},
	})
	eventManager.RLock()
	assert.Len(t, eventManager.FsEvents, 0)
	assert.Len(t, eventManager.ProviderEvents, 0)
	assert.Len(t, eventManager.Schedules, 0)
	eventManager.RUnlock()
	// restore the provider for the other tests
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	stopEventScheduler()
}
// TestEventRuleActions exercises executeRuleAction for the supported action
// types: backup, HTTP (invalid endpoint, success, 404, unreachable host,
// undecryptable password), user disk/transfer quota reset and folder quota
// reset, including the "scan already in progress" and "no home dir" error
// paths.
func TestEventRuleActions(t *testing.T) {
	actionName := "test rule action"
	action := dataprovider.BaseEventAction{
		Name: actionName,
		Type: dataprovider.ActionTypeBackup,
	}
	err := executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
	assert.NoError(t, err)
	// unknown action type must be rejected
	action.Type = -1
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
	assert.Error(t, err)
	action = dataprovider.BaseEventAction{
		Name: actionName,
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "http://foo\x7f.com/", // invalid URL
				SkipTLSVerify: true,
				Body:          "{{ObjectData}}",
				Method:        http.MethodPost,
				QueryParameters: []dataprovider.KeyValue{
					{
						Key:   "param",
						Value: "value",
					},
				},
				Timeout: 5,
				Headers: []dataprovider.KeyValue{
					{
						Key:   "Content-Type",
						Value: "application/json",
					},
				},
				Username: "httpuser",
			},
		},
	}
	action.Options.SetEmptySecretsIfNil()
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "invalid endpoint")
	}
	// valid endpoint served by the test HTTP server
	action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr)
	params := EventParams{
		Name: "a",
		Object: &dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: "test user",
			},
		},
	}
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.NoError(t, err)
	// a non 2xx status code is an error
	action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v/404", httpAddr)
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Equal(t, err.Error(), "unexpected status code: 404")
	}
	action.Options.HTTPConfig.Endpoint = "http://invalid:1234"
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.Error(t, err)
	action.Options.HTTPConfig.QueryParameters = nil
	action.Options.HTTPConfig.Endpoint = "http://bar\x7f.com/"
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	assert.Error(t, err)
	// a password that cannot be decrypted must fail before sending
	action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", "data")
	err = executeRuleAction(action, params, dataprovider.ConditionOptions{})
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unable to decrypt password")
	}
	// test disk and transfer quota reset
	username1 := "user1"
	username2 := "user2"
	user1 := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username1,
			HomeDir:  filepath.Join(os.TempDir(), username1),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	user2 := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username2,
			HomeDir:  filepath.Join(os.TempDir(), username2),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
	}
	err = dataprovider.AddUser(&user1, "", "")
	assert.NoError(t, err)
	err = dataprovider.AddUser(&user2, "", "")
	assert.NoError(t, err)
	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeUserQuotaReset,
	}
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err) // no home dir
	// create the home dir
	err = os.MkdirAll(user1.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user1.GetHomeDir(), "file.txt"), []byte("user"), 0666)
	assert.NoError(t, err)
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	userGet, err := dataprovider.UserExists(username1)
	assert.NoError(t, err)
	assert.Equal(t, 1, userGet.UsedQuotaFiles)
	assert.Equal(t, int64(4), userGet.UsedQuotaSize)
	// simulate another quota scan in progress
	assert.True(t, QuotaScans.AddUserQuotaScan(username1))
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.Error(t, err)
	assert.True(t, QuotaScans.RemoveUserQuotaScan(username1))
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.UpdateUserTransferQuota(&user1, 100, 100, true)
	assert.NoError(t, err)
	action.Type = dataprovider.ActionTypeTransferQuotaReset
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: username1,
			},
		},
	})
	assert.NoError(t, err)
	userGet, err = dataprovider.UserExists(username1)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer)
	assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer)
	err = dataprovider.DeleteUser(username1, "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username2, "", "")
	assert.NoError(t, err)
	// test folder quota reset
	foldername1 := "f1"
	foldername2 := "f2"
	folder1 := vfs.BaseVirtualFolder{
		Name:       foldername1,
		MappedPath: filepath.Join(os.TempDir(), foldername1),
	}
	folder2 := vfs.BaseVirtualFolder{
		Name:       foldername2,
		MappedPath: filepath.Join(os.TempDir(), foldername2),
	}
	err = dataprovider.AddFolder(&folder1, "", "")
	assert.NoError(t, err)
	err = dataprovider.AddFolder(&folder2, "", "")
	assert.NoError(t, err)
	action = dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeFolderQuotaReset,
	}
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.Error(t, err) // no home dir
	err = os.MkdirAll(folder1.MappedPath, os.ModePerm)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(folder1.MappedPath, "file.txt"), []byte("folder"), 0666)
	assert.NoError(t, err)
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.NoError(t, err)
	folderGet, err := dataprovider.GetFolderByName(foldername1)
	assert.NoError(t, err)
	assert.Equal(t, 1, folderGet.UsedQuotaFiles)
	assert.Equal(t, int64(6), folderGet.UsedQuotaSize)
	// simulate another quota scan in progress
	assert.True(t, QuotaScans.AddVFolderQuotaScan(foldername1))
	err = executeRuleAction(action, EventParams{}, dataprovider.ConditionOptions{
		Names: []dataprovider.ConditionPattern{
			{
				Pattern: foldername1,
			},
		},
	})
	assert.Error(t, err)
	assert.True(t, QuotaScans.RemoveVFolderQuotaScan(foldername1))
	err = os.RemoveAll(folder1.MappedPath)
	assert.NoError(t, err)
	err = dataprovider.DeleteFolder(foldername1, "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteFolder(foldername2, "", "")
	assert.NoError(t, err)
}
// TestQuotaActionsWithQuotaTrackDisabled verifies that the user, transfer
// and folder quota reset actions fail when the provider is re-initialized
// with quota tracking disabled (TrackQuota = 0), then restores the original
// provider configuration.
func TestQuotaActionsWithQuotaTrackDisabled(t *testing.T) {
	oldProviderConf := dataprovider.GetProviderConfig()
	providerConf := dataprovider.GetProviderConfig()
	providerConf.TrackQuota = 0
	// re-initialize the provider with quota tracking disabled
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	username := "u1"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		FsConfig: vfs.Filesystem{
			Provider: sdk.LocalFilesystemProvider,
		},
	}
	err = dataprovider.AddUser(&user, "", "")
	assert.NoError(t, err)
	err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeUserQuotaReset},
		EventParams{}, dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: username,
				},
			},
		})
	assert.Error(t, err)
	err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeTransferQuotaReset},
		EventParams{}, dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: username,
				},
			},
		})
	assert.Error(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "")
	assert.NoError(t, err)
	foldername := "f1"
	folder := vfs.BaseVirtualFolder{
		Name:       foldername,
		MappedPath: filepath.Join(os.TempDir(), foldername),
	}
	err = dataprovider.AddFolder(&folder, "", "")
	assert.NoError(t, err)
	err = os.MkdirAll(folder.MappedPath, os.ModePerm)
	assert.NoError(t, err)
	err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeFolderQuotaReset},
		EventParams{}, dataprovider.ConditionOptions{
			Names: []dataprovider.ConditionPattern{
				{
					Pattern: foldername,
				},
			},
		})
	assert.Error(t, err)
	err = os.RemoveAll(folder.MappedPath)
	assert.NoError(t, err)
	err = dataprovider.DeleteFolder(foldername, "", "")
	assert.NoError(t, err)
	// restore the original provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = dataprovider.Initialize(oldProviderConf, configDir, true)
	assert.NoError(t, err)
}
// TestScheduledActions verifies that running an eventCronJob for a scheduled
// rule with a backup action creates the backups directory, and that running
// it for a missing rule is a no-op.
func TestScheduledActions(t *testing.T) {
	startEventScheduler()
	backupsPath := filepath.Join(os.TempDir(), "backups")
	err := os.RemoveAll(backupsPath)
	assert.NoError(t, err)
	action := &dataprovider.BaseEventAction{
		Name: "action",
		Type: dataprovider.ActionTypeBackup,
	}
	err = dataprovider.AddEventAction(action, "", "")
	assert.NoError(t, err)
	rule := &dataprovider.EventRule{
		Name:    "rule",
		Trigger: dataprovider.EventTriggerSchedule,
		Conditions: dataprovider.EventConditions{
			Schedules: []dataprovider.Schedule{
				{
					Hours:      "11",
					DayOfWeek:  "*",
					DayOfMonth: "*",
					Month:      "*",
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	job := eventCronJob{
		ruleName: rule.Name,
	}
	job.Run() // rule not found
	assert.NoDirExists(t, backupsPath)
	err = dataprovider.AddEventRule(rule, "", "")
	assert.NoError(t, err)
	// the backup action must create the backups directory
	job.Run()
	assert.DirExists(t, backupsPath)
	err = dataprovider.DeleteEventRule(rule.Name, "", "")
	assert.NoError(t, err)
	err = dataprovider.DeleteEventAction(action.Name, "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(backupsPath)
	assert.NoError(t, err)
	stopEventScheduler()
}

View file

@ -0,0 +1,43 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package common
import (
"time"
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/internal/util"
)
var (
	// eventScheduler is the cron scheduler used to periodically reload the
	// event rules; nil when the scheduler is not running.
	eventScheduler *cron.Cron
)
// stopEventScheduler stops the event scheduler, if running, and clears the
// package level reference.
func stopEventScheduler() {
	if eventScheduler == nil {
		return
	}
	eventScheduler.Stop()
	eventScheduler = nil
}
// startEventScheduler (re)starts the event scheduler: any running scheduler
// is stopped first, then a new UTC-based cron scheduler is created that
// reloads the event rules every 10 minutes.
func startEventScheduler() {
	stopEventScheduler()
	eventScheduler = cron.New(cron.WithLocation(time.UTC))
	// panic on error: a failure to register the reload job is a programming
	// error, not a runtime condition
	_, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules)
	util.PanicOnError(err)
	eventScheduler.Start()
}

View file

@ -18,6 +18,7 @@ import (
"bufio"
"bytes"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"math"
@ -77,6 +78,7 @@ var (
allPerms = []string{dataprovider.PermAny}
homeBasePath string
logFilePath string
backupsPath string
testFileContent = []byte("test data")
lastReceivedEmail receivedEmail
)
@ -84,6 +86,7 @@ var (
func TestMain(m *testing.M) {
homeBasePath = os.TempDir()
logFilePath = filepath.Join(configDir, "common_test.log")
backupsPath = filepath.Join(os.TempDir(), "backups")
logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel)
os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1")
@ -95,6 +98,7 @@ func TestMain(m *testing.M) {
os.Exit(1)
}
providerConf := config.GetProviderConf()
providerConf.BackupsPath = backupsPath
logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver)
err = common.Initialize(config.GetCommonConfig(), 0)
@ -203,6 +207,7 @@ func TestMain(m *testing.M) {
exitCode := m.Run()
os.Remove(logFilePath)
os.RemoveAll(backupsPath)
os.Exit(exitCode)
}
@ -2848,7 +2853,7 @@ func TestEventRule(t *testing.T) {
EmailConfig: dataprovider.EventActionEmailConfig{
Recipients: []string{"test1@example.com", "test2@example.com"},
Subject: `New "{{Event}}" from "{{Name}}"`,
Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}}",
Body: "Fs path {{FsPath}}, size: {{FileSize}}, protocol: {{Protocol}}, IP: {{IP}} Data: {{ObjectData}}",
},
},
}
@ -2987,6 +2992,10 @@ func TestEventRule(t *testing.T) {
Key: "SFTPGO_ACTION_PATH",
Value: "{{FsPath}}",
},
{
Key: "CUSTOM_ENV_VAR",
Value: "value",
},
},
},
}
@ -3090,6 +3099,176 @@ func TestEventRule(t *testing.T) {
require.NoError(t, err)
}
func TestEventRuleProviderEvents(t *testing.T) {
if runtime.GOOS == osWindows {
t.Skip("this test is not available on Windows")
}
smtpCfg := smtp.Config{
Host: "127.0.0.1",
Port: 2525,
From: "notification@example.com",
TemplatesPath: "templates",
}
err := smtpCfg.Initialize(configDir)
require.NoError(t, err)
saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh")
outPath := filepath.Join(os.TempDir(), "provider_out.json")
err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755)
assert.NoError(t, err)
a1 := dataprovider.BaseEventAction{
Name: "a1",
Type: dataprovider.ActionTypeCommand,
Options: dataprovider.BaseEventActionOptions{
CmdConfig: dataprovider.EventActionCommandConfig{
Cmd: saveObjectScriptPath,
Timeout: 10,
EnvVars: []dataprovider.KeyValue{
{
Key: "SFTPGO_OBJECT_DATA",
Value: "{{ObjectData}}",
},
},
},
},
}
a2 := dataprovider.BaseEventAction{
Name: "a2",
Type: dataprovider.ActionTypeEmail,
Options: dataprovider.BaseEventActionOptions{
EmailConfig: dataprovider.EventActionEmailConfig{
Recipients: []string{"test3@example.com"},
Subject: `New "{{Event}}" from "{{Name}}"`,
Body: "Object name: {{ObjectName}} object type: {{ObjectType}} Data: {{ObjectData}}",
},
},
}
a3 := dataprovider.BaseEventAction{
Name: "a3",
Type: dataprovider.ActionTypeEmail,
Options: dataprovider.BaseEventActionOptions{
EmailConfig: dataprovider.EventActionEmailConfig{
Recipients: []string{"failure@example.com"},
Subject: `Failed "{{Event}}" from "{{Name}}"`,
Body: "Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}",
},
},
}
action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
assert.NoError(t, err)
action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
assert.NoError(t, err)
action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated)
assert.NoError(t, err)
r := dataprovider.EventRule{
Name: "rule",
Trigger: dataprovider.EventTriggerProviderEvent,
Conditions: dataprovider.EventConditions{
ProviderEvents: []string{"update"},
},
Actions: []dataprovider.EventAction{
{
BaseEventAction: dataprovider.BaseEventAction{
Name: action1.Name,
},
Order: 1,
Options: dataprovider.EventActionOptions{
StopOnFailure: true,
},
},
{
BaseEventAction: dataprovider.BaseEventAction{
Name: action2.Name,
},
Order: 2,
},
{
BaseEventAction: dataprovider.BaseEventAction{
Name: action3.Name,
},
Order: 3,
Options: dataprovider.EventActionOptions{
IsFailureAction: true,
StopOnFailure: true,
},
},
},
}
rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated)
assert.NoError(t, err)
lastReceivedEmail.reset()
// create and update a folder to trigger the rule
folder := vfs.BaseVirtualFolder{
Name: "ftest rule",
MappedPath: filepath.Join(os.TempDir(), "p"),
}
folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
assert.NoError(t, err)
// no action is triggered on add
assert.NoFileExists(t, outPath)
// update the folder
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
assert.NoError(t, err)
if assert.FileExists(t, outPath) {
content, err := os.ReadFile(outPath)
assert.NoError(t, err)
var folderGet vfs.BaseVirtualFolder
err = json.Unmarshal(content, &folderGet)
assert.NoError(t, err)
assert.Equal(t, folder, folderGet)
err = os.Remove(outPath)
assert.NoError(t, err)
assert.Eventually(t, func() bool {
return lastReceivedEmail.get().From != ""
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "test3@example.com"))
assert.Contains(t, string(email.Data), `Subject: New "update" from "admin"`)
}
// now delete the script to generate an error
lastReceivedEmail.reset()
err = os.Remove(saveObjectScriptPath)
assert.NoError(t, err)
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
assert.NoError(t, err)
assert.NoFileExists(t, outPath)
assert.Eventually(t, func() bool {
return lastReceivedEmail.get().From != ""
}, 3000*time.Millisecond, 100*time.Millisecond)
email := lastReceivedEmail.get()
assert.Len(t, email.To, 1)
assert.True(t, util.Contains(email.To, "failure@example.com"))
assert.Contains(t, string(email.Data), `Subject: Failed "update" from "admin"`)
assert.Contains(t, string(email.Data), fmt.Sprintf("Object name: %s object type: folder", folder.Name))
lastReceivedEmail.reset()
// generate an error for the failure action
smtpCfg = smtp.Config{}
err = smtpCfg.Initialize(configDir)
require.NoError(t, err)
_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
assert.NoError(t, err)
assert.NoFileExists(t, outPath)
email = lastReceivedEmail.get()
assert.Len(t, email.To, 0)
_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
assert.NoError(t, err)
}
func TestSyncUploadAction(t *testing.T) {
if runtime.GOOS == osWindows {
t.Skip("this test is not available on Windows")
@ -4113,6 +4292,13 @@ func getUploadScriptContent(movedPath string, exitStatus int) []byte {
return content
}
// getSaveProviderObjectScriptContent returns a shell script that writes the
// SFTPGO_OBJECT_DATA environment variable to outFilePath and terminates with
// the given exit status.
func getSaveProviderObjectScriptContent(outFilePath string, exitStatus int) []byte {
	script := fmt.Sprintf("#!/bin/sh\n\necho ${SFTPGO_OBJECT_DATA} > %v\nexit %d", outFilePath, exitStatus)
	return []byte(script)
}
func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) {
return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
Period: 30,

View file

@ -63,17 +63,8 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
Timestamp: time.Now().UnixNano(),
}, object)
}
if EventManager.hasProviderEvents() {
EventManager.handleProviderEvent(EventParams{
Name: executor,
ObjectName: objectName,
Event: operation,
Status: 1,
ObjectType: objectType,
IP: ip,
Timestamp: time.Now().UnixNano(),
Object: object,
})
if fnHandleRuleForProviderEvent != nil {
fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, object)
}
if config.Actions.Hook == "" {
return

View file

@ -191,6 +191,9 @@ var (
lastLoginMinDelay = 10 * time.Minute
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9-_.~]+$")
tempPath string
fnReloadRules FnReloadRules
fnRemoveRule FnRemoveRule
fnHandleRuleForProviderEvent FnHandleRuleForProviderEvent
)
func initSQLTables() {
@ -214,6 +217,22 @@ func initSQLTables() {
sqlTableSchemaVersion = "schema_version"
}
// FnReloadRules defines the callback to reload event rules.
type FnReloadRules func()

// FnRemoveRule defines the callback to remove an event rule.
type FnRemoveRule func(name string)

// FnHandleRuleForProviderEvent defines the callback to handle event rules for provider events.
type FnHandleRuleForProviderEvent func(operation, executor, ip, objectType, objectName string, object plugin.Renderer)

// SetEventRulesCallbacks sets the event rules callbacks.
// They are invoked, when non-nil, after provider changes to event actions/rules
// so the external event manager can keep its in-memory state in sync.
func SetEventRulesCallbacks(reload FnReloadRules, remove FnRemoveRule, handle FnHandleRuleForProviderEvent) {
	fnReloadRules = reload
	fnRemoveRule = remove
	fnHandleRuleForProviderEvent = handle
}
type schemaVersion struct {
Version int
}
@ -487,31 +506,36 @@ func (c *Config) requireCustomTLSForMySQL() bool {
func (c *Config) doBackup() error {
now := time.Now().UTC()
outputFile := filepath.Join(c.BackupsPath, fmt.Sprintf("backup_%s_%d.json", now.Weekday(), now.Hour()))
eventManagerLog(logger.LevelDebug, "starting backup to file %q", outputFile)
providerLog(logger.LevelDebug, "starting backup to file %q", outputFile)
err := os.MkdirAll(filepath.Dir(outputFile), 0700)
if err != nil {
eventManagerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err)
providerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err)
return fmt.Errorf("unable to create backup dir: %w", err)
}
backup, err := DumpData()
if err != nil {
eventManagerLog(logger.LevelError, "unable to execute backup: %v", err)
providerLog(logger.LevelError, "unable to execute backup: %v", err)
return fmt.Errorf("unable to dump backup data: %w", err)
}
dump, err := json.Marshal(backup)
if err != nil {
eventManagerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err)
providerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err)
return fmt.Errorf("unable to marshal backup data as JSON: %w", err)
}
err = os.WriteFile(outputFile, dump, 0600)
if err != nil {
eventManagerLog(logger.LevelError, "unable to save backup: %v", err)
providerLog(logger.LevelError, "unable to save backup: %v", err)
return fmt.Errorf("unable to save backup: %w", err)
}
eventManagerLog(logger.LevelDebug, "auto backup saved to %q", outputFile)
providerLog(logger.LevelDebug, "backup saved to %q", outputFile)
return nil
}
// ExecuteBackup executes a backup, dumping all the provider data to a JSON
// file inside the configured backups path.
func ExecuteBackup() error {
	return config.doBackup()
}
// ConvertName converts the given name based on the configured rules
func ConvertName(name string) string {
return config.convertName(name)
@ -1568,7 +1592,9 @@ func AddEventAction(action *BaseEventAction, executor, ipAddress string) error {
func UpdateEventAction(action *BaseEventAction, executor, ipAddress string) error {
err := provider.updateEventAction(action)
if err == nil {
EventManager.loadRules()
if fnReloadRules != nil {
fnReloadRules()
}
executeAction(operationUpdate, executor, ipAddress, actionObjectEventAction, action.Name, action)
}
return err
@ -1597,6 +1623,11 @@ func GetEventRules(limit, offset int, order string) ([]EventRule, error) {
return provider.getEventRules(limit, offset, order)
}
// GetRecentlyUpdatedRules returns the event rules updated after the specified
// time, expressed as milliseconds since epoch.
func GetRecentlyUpdatedRules(after int64) ([]EventRule, error) {
	return provider.getRecentlyUpdatedRules(after)
}
// EventRuleExists returns the event rule with the given name if it exists
func EventRuleExists(name string) (EventRule, error) {
name = config.convertName(name)
@ -1608,7 +1639,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error {
rule.Name = config.convertName(rule.Name)
err := provider.addEventRule(rule)
if err == nil {
EventManager.loadRules()
if fnReloadRules != nil {
fnReloadRules()
}
executeAction(operationAdd, executor, ipAddress, actionObjectEventRule, rule.Name, rule)
}
return err
@ -1618,7 +1651,9 @@ func AddEventRule(rule *EventRule, executor, ipAddress string) error {
func UpdateEventRule(rule *EventRule, executor, ipAddress string) error {
err := provider.updateEventRule(rule)
if err == nil {
EventManager.loadRules()
if fnReloadRules != nil {
fnReloadRules()
}
executeAction(operationUpdate, executor, ipAddress, actionObjectEventRule, rule.Name, rule)
}
return err
@ -1633,12 +1668,39 @@ func DeleteEventRule(name string, executor, ipAddress string) error {
}
err = provider.deleteEventRule(rule, config.IsShared == 1)
if err == nil {
EventManager.RemoveRule(rule.Name)
if fnRemoveRule != nil {
fnRemoveRule(rule.Name)
}
executeAction(operationDelete, executor, ipAddress, actionObjectEventRule, rule.Name, &rule)
}
return err
}
// RemoveEventRule deletes an existing event rule without marking it as deleted.
// Note the second argument is always false here, unlike DeleteEventRule which
// passes config.IsShared == 1.
func RemoveEventRule(rule EventRule) error {
	return provider.deleteEventRule(rule, false)
}

// GetTaskByName returns the task with the specified name.
func GetTaskByName(name string) (Task, error) {
	return provider.getTaskByName(name)
}

// AddTask adds a task with the specified name.
func AddTask(name string) error {
	return provider.addTask(name)
}

// UpdateTask updates the task with the specified name and version.
func UpdateTask(name string, version int64) error {
	return provider.updateTask(name, version)
}

// UpdateTaskTimestamp updates the timestamp for the task with the specified name.
func UpdateTaskTimestamp(name string) error {
	return provider.updateTaskTimestamp(name)
}
// HasAdmin returns true if the first admin has been created
// and so SFTPGo is ready to be used
func HasAdmin() bool {
@ -1971,6 +2033,16 @@ func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtua
return provider.getFolders(limit, offset, order, minimal)
}
// DumpUsers returns all users, including confidential data.
// Exported so code outside this package (presumably the relocated event
// manager) can access provider dumps — confirm against callers.
func DumpUsers() ([]User, error) {
	return provider.dumpUsers()
}

// DumpFolders returns all folders, including confidential data.
func DumpFolders() ([]vfs.BaseVirtualFolder, error) {
	return provider.dumpFolders()
}
// DumpData returns all users, groups, folders, admins, api keys, shares, actions, rules
func DumpData() (BackupData, error) {
var data BackupData

View file

@ -15,16 +15,10 @@
package dataprovider
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
@ -34,9 +28,7 @@ import (
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
// Supported event actions
@ -204,25 +196,8 @@ func (c *EventActionHTTPConfig) validate(additionalData string) error {
return nil
}
// getEndpoint returns the configured endpoint with any configured query
// parameters appended. Placeholder values inside query parameters are
// expanded using the given replacer. With no query parameters the endpoint
// is returned unmodified (and unvalidated).
func (c *EventActionHTTPConfig) getEndpoint(replacer *strings.Replacer) (string, error) {
	if len(c.QueryParameters) > 0 {
		u, err := url.Parse(c.Endpoint)
		if err != nil {
			return "", fmt.Errorf("invalid endpoint: %w", err)
		}
		q := u.Query()
		for _, keyVal := range c.QueryParameters {
			q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
		}
		u.RawQuery = q.Encode()
		return u.String(), nil
	}
	return c.Endpoint, nil
}
func (c *EventActionHTTPConfig) getHTTPClient() *http.Client {
// GetHTTPClient returns an HTTP client based on the config
func (c *EventActionHTTPConfig) GetHTTPClient() *http.Client {
client := &http.Client{
Timeout: time.Duration(c.Timeout) * time.Second,
}
@ -241,63 +216,6 @@ func (c *EventActionHTTPConfig) getHTTPClient() *http.Client {
return client
}
// execute sends the configured HTTP notification for the given event params.
// Placeholders ("{{...}}") are expanded in the endpoint query parameters,
// username, headers and body. Status codes between 200 and 204 inclusive are
// treated as success.
func (c *EventActionHTTPConfig) execute(params EventParams) error {
	if !c.Password.IsEmpty() {
		if err := c.Password.TryDecrypt(); err != nil {
			return fmt.Errorf("unable to decrypt password: %w", err)
		}
	}
	// render the object as JSON only if the body actually references it
	// (the original had a redundant "if !addObjectData" nesting here)
	addObjectData := params.Object != nil && strings.Contains(c.Body, "{{ObjectData}}")
	replacements := params.getStringReplacements(addObjectData)
	replacer := strings.NewReplacer(replacements...)
	endpoint, err := c.getEndpoint(replacer)
	if err != nil {
		return err
	}
	var body io.Reader
	if c.Body != "" && c.Method != http.MethodGet {
		body = bytes.NewBufferString(replaceWithReplacer(c.Body, replacer))
	}
	req, err := http.NewRequest(c.Method, endpoint, body)
	if err != nil {
		return err
	}
	if c.Username != "" {
		// NOTE(review): using GetAdditionalData() as the basic-auth password
		// looks suspicious — confirm it shouldn't be the decrypted payload
		req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetAdditionalData())
	}
	for _, keyVal := range c.Headers {
		req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer))
	}
	client := c.getHTTPClient()
	defer client.CloseIdleConnections()
	startTime := time.Now()
	resp, err := client.Do(req)
	if err != nil {
		eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v",
			endpoint, time.Since(startTime), err)
		return err
	}
	defer resp.Body.Close()
	// fixed log typo: "endopoint" -> "endpoint"
	eventManagerLog(logger.LevelDebug, "http notification sent, endpoint: %s, elapsed: %s, status code: %d",
		endpoint, time.Since(startTime), resp.StatusCode)
	if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	return nil
}
// EventActionCommandConfig defines the configuration for a command event target
type EventActionCommandConfig struct {
Cmd string `json:"cmd"`
@ -323,43 +241,6 @@ func (c *EventActionCommandConfig) validate() error {
return nil
}
// getEnvVars returns the configured environment variables as "KEY=value"
// strings, with the supported placeholders expanded. The ObjectData
// placeholder is only rendered when some configured value references it.
func (c *EventActionCommandConfig) getEnvVars(params EventParams) []string {
	withObjectData := false
	if params.Object != nil {
		for _, kv := range c.EnvVars {
			if strings.Contains(kv.Value, "{{ObjectData}}") {
				withObjectData = true
				break
			}
		}
	}
	replacer := strings.NewReplacer(params.getStringReplacements(withObjectData)...)
	result := make([]string, 0, len(c.EnvVars))
	for _, kv := range c.EnvVars {
		result = append(result, kv.Key+"="+replaceWithReplacer(kv.Value, replacer))
	}
	return result
}
// execute runs the configured command, enforcing the configured timeout
// (seconds) through the context. The command inherits the current process
// environment plus the configured, placeholder-expanded, env vars.
func (c *EventActionCommandConfig) execute(params EventParams) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second)
	defer cancel()
	cmd := exec.CommandContext(ctx, c.Cmd)
	cmd.Env = append(cmd.Env, os.Environ()...)
	cmd.Env = append(cmd.Env, c.getEnvVars(params)...)
	startTime := time.Now()
	err := cmd.Run()
	eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v",
		c.Cmd, time.Since(startTime), err)
	return err
}
// EventActionEmailConfig defines the configuration options for SMTP event actions
type EventActionEmailConfig struct {
Recipients []string `json:"recipients"`
@ -391,24 +272,6 @@ func (o *EventActionEmailConfig) validate() error {
return nil
}
// execute sends the configured email notification, expanding the supported
// placeholders in both subject and body. The send outcome is logged and the
// SMTP error, if any, is returned.
func (o *EventActionEmailConfig) execute(params EventParams) error {
	withObjectData := params.Object != nil && strings.Contains(o.Body, "{{ObjectData}}")
	replacer := strings.NewReplacer(params.getStringReplacements(withObjectData)...)
	subject := replaceWithReplacer(o.Subject, replacer)
	body := replaceWithReplacer(o.Body, replacer)
	startTime := time.Now()
	err := smtp.SendEmail(o.Recipients, subject, body, smtp.EmailContentTypeTextPlain)
	eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v",
		time.Since(startTime), err)
	return err
}
// BaseEventActionOptions defines the supported configuration options for a base event actions
type BaseEventActionOptions struct {
HTTPConfig EventActionHTTPConfig `json:"http_config"`
@ -560,130 +423,6 @@ func (a *BaseEventAction) validate() error {
return a.Options.validate(a.Type, a.Name)
}
// doUsersQuotaReset rescans and persists the used quota for every user whose
// username matches the name condition patterns. Users are processed
// best-effort: per-user failures are collected and reported in a single
// error at the end instead of aborting the loop.
func (a *BaseEventAction) doUsersQuotaReset(conditions ConditionOptions) error {
	users, err := provider.dumpUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failedResets []string
	for _, user := range users {
		if !checkConditionPatterns(user.Username, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for user %s, name conditions don't match",
				user.Username)
			continue
		}
		// skip users with an in-progress scan to avoid concurrent quota updates
		if !QuotaScans.AddUserQuotaScan(user.Username) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %s", user.Username)
			failedResets = append(failedResets, user.Username)
			continue
		}
		numFiles, size, err := user.ScanQuota()
		QuotaScans.RemoveUserQuotaScan(user.Username)
		if err != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for user %s: %v", user.Username, err)
			failedResets = append(failedResets, user.Username)
			continue
		}
		err = UpdateUserQuota(&user, numFiles, size, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for user %s: %v", user.Username, err)
			failedResets = append(failedResets, user.Username)
			continue
		}
	}
	if len(failedResets) > 0 {
		return fmt.Errorf("quota reset failed for users: %+v", failedResets)
	}
	return nil
}
// doFoldersQuotaReset rescans and persists the used quota for every virtual
// folder whose name matches the condition patterns. Like doUsersQuotaReset,
// failures are collected and reported in a single error at the end.
func (a *BaseEventAction) doFoldersQuotaReset(conditions ConditionOptions) error {
	folders, err := provider.dumpFolders()
	if err != nil {
		return fmt.Errorf("unable to get folders: %w", err)
	}
	var failedResets []string
	for _, folder := range folders {
		if !checkConditionPatterns(folder.Name, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match",
				folder.Name)
			continue
		}
		// skip folders with an in-progress scan to avoid concurrent quota updates
		if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %s", folder.Name)
			failedResets = append(failedResets, folder.Name)
			continue
		}
		// wrap the base folder so it can be scanned from its root
		f := vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/",
		}
		numFiles, size, err := f.ScanQuota()
		QuotaScans.RemoveVFolderQuotaScan(folder.Name)
		if err != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for folder %s: %v", folder.Name, err)
			failedResets = append(failedResets, folder.Name)
			continue
		}
		err = UpdateVirtualFolderQuota(&folder, numFiles, size, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for folder %s: %v", folder.Name, err)
			failedResets = append(failedResets, folder.Name)
			continue
		}
	}
	if len(failedResets) > 0 {
		return fmt.Errorf("quota reset failed for folders: %+v", failedResets)
	}
	return nil
}
// doTransferQuotaReset zeroes the used transfer quota (upload and download
// counters) for every user whose username matches the name condition
// patterns. Failures are collected and reported in a single error.
func (a *BaseEventAction) doTransferQuotaReset(conditions ConditionOptions) error {
	users, err := provider.dumpUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failedResets []string
	for _, user := range users {
		if !checkConditionPatterns(user.Username, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, name conditions don't match",
				user.Username)
			continue
		}
		err = UpdateUserTransferQuota(&user, 0, 0, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating transfer quota for user %s: %v", user.Username, err)
			failedResets = append(failedResets, user.Username)
			continue
		}
	}
	if len(failedResets) > 0 {
		return fmt.Errorf("transfer quota reset failed for users: %+v", failedResets)
	}
	return nil
}
// execute dispatches to the implementation matching the action type and
// returns its result. Unknown types return an error.
func (a *BaseEventAction) execute(params EventParams, conditions ConditionOptions) error {
	switch a.Type {
	case ActionTypeHTTP:
		return a.Options.HTTPConfig.execute(params)
	case ActionTypeCommand:
		return a.Options.CmdConfig.execute(params)
	case ActionTypeEmail:
		return a.Options.EmailConfig.execute(params)
	case ActionTypeBackup:
		return config.doBackup()
	case ActionTypeUserQuotaReset:
		return a.doUsersQuotaReset(conditions)
	case ActionTypeFolderQuotaReset:
		return a.doFoldersQuotaReset(conditions)
	case ActionTypeTransferQuotaReset:
		return a.doTransferQuotaReset(conditions)
	default:
		return fmt.Errorf("unsupported action type: %d", a.Type)
	}
}
// EventActionOptions defines the supported configuration options for an event action
type EventActionOptions struct {
IsFailureAction bool `json:"is_failure_action"`
@ -731,18 +470,6 @@ type ConditionPattern struct {
InverseMatch bool `json:"inverse_match,omitempty"`
}
// match reports whether name matches the shell-style pattern, honoring
// InverseMatch. A malformed pattern is logged and never matches, regardless
// of InverseMatch.
func (p *ConditionPattern) match(name string) bool {
	matched, err := path.Match(p.Pattern, name)
	if err != nil {
		eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err)
		return false
	}
	// XOR with InverseMatch: inverted patterns succeed when the match fails
	return matched != p.InverseMatch
}
func (p *ConditionPattern) validate() error {
if p.Pattern == "" {
return util.NewValidationError("empty condition pattern not allowed")
@ -826,12 +553,13 @@ type Schedule struct {
Month string `json:"month"`
}
func (s *Schedule) getCronSpec() string {
// GetCronSpec returns the cron compatible schedule string
func (s *Schedule) GetCronSpec() string {
return fmt.Sprintf("0 %s %s %s %s", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek)
}
func (s *Schedule) validate() error {
_, err := cron.ParseStandard(s.getCronSpec())
_, err := cron.ParseStandard(s.GetCronSpec())
if err != nil {
return util.NewValidationError(fmt.Sprintf("invalid schedule, hour: %q, day of month: %q, month: %q, day of week: %q",
s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek))
@ -871,51 +599,6 @@ func (c *EventConditions) getACopy() EventConditions {
}
}
// ProviderEventMatch returns true if the specified provider event matches
// this rule's conditions: the event name, the name patterns and, when
// configured, the provider object types.
func (c *EventConditions) ProviderEventMatch(params EventParams) bool {
	switch {
	case !util.Contains(c.ProviderEvents, params.Event):
		return false
	case !checkConditionPatterns(params.Name, c.Options.Names):
		return false
	case len(c.Options.ProviderObjects) > 0 && !util.Contains(c.Options.ProviderObjects, params.ObjectType):
		return false
	default:
		return true
	}
}
// FsEventMatch returns true if the specified filesystem event matches all the
// rule conditions: event name, name patterns, path patterns, protocols and,
// for uploads/downloads only, the configured file size limits.
func (c *EventConditions) FsEventMatch(params EventParams) bool {
	if !util.Contains(c.FsEvents, params.Event) {
		return false
	}
	if !checkConditionPatterns(params.Name, c.Options.Names) {
		return false
	}
	// a path condition is satisfied by either the virtual path or the object name
	if !checkConditionPatterns(params.VirtualPath, c.Options.FsPaths) {
		if !checkConditionPatterns(params.ObjectName, c.Options.FsPaths) {
			return false
		}
	}
	if len(c.Options.Protocols) > 0 && !util.Contains(c.Options.Protocols, params.Protocol) {
		return false
	}
	// size limits only apply to transfer events; 0 means "no limit"
	if params.Event == "upload" || params.Event == "download" {
		if c.Options.MinFileSize > 0 {
			if params.FileSize < c.Options.MinFileSize {
				return false
			}
		}
		if c.Options.MaxFileSize > 0 {
			if params.FileSize > c.Options.MaxFileSize {
				return false
			}
		}
	}
	return true
}
func (c *EventConditions) validate(trigger int) error {
switch trigger {
case EventTriggerFsEvent:
@ -1015,7 +698,9 @@ func (r *EventRule) getACopy() EventRule {
}
}
func (r *EventRule) guardFromConcurrentExecution() bool {
// GuardFromConcurrentExecution returns true if the rule cannot be executed concurrently
// from multiple instances
func (r *EventRule) GuardFromConcurrentExecution() bool {
if config.IsShared == 0 {
return false
}
@ -1102,6 +787,28 @@ func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) {
return json.Marshal(r)
}
// cloneKeyValues returns an independent copy of the given key/value pairs.
func cloneKeyValues(keyVals []KeyValue) []KeyValue {
	cloned := make([]KeyValue, len(keyVals))
	for idx, kv := range keyVals {
		cloned[idx] = KeyValue{
			Key:   kv.Key,
			Value: kv.Value,
		}
	}
	return cloned
}
// cloneConditionPatterns returns an independent copy of the given patterns.
func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern {
	cloned := make([]ConditionPattern, len(patterns))
	for idx, pattern := range patterns {
		cloned[idx] = ConditionPattern{
			Pattern:      pattern.Pattern,
			InverseMatch: pattern.InverseMatch,
		}
	}
	return cloned
}
// Task stores the state for a scheduled task
type Task struct {
Name string `json:"name"`

View file

@ -1,474 +0,0 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
)
var (
	// EventManager handles the supported event rules actions.
	EventManager EventRulesContainer
)

// init creates the EventManager with an empty schedules mapping.
func init() {
	EventManager = EventRulesContainer{
		schedulesMapping: make(map[string][]cron.EntryID),
	}
}
// EventRulesContainer stores event rules by trigger.
// The embedded RWMutex guards the rule slices and the schedules mapping.
type EventRulesContainer struct {
	sync.RWMutex
	FsEvents       []EventRule
	ProviderEvents []EventRule
	Schedules      []EventRule
	// rule name -> cron entry IDs registered in the scheduler for that rule
	schedulesMapping map[string][]cron.EntryID
	// timestamp (ms since epoch, see loadRules) of the last rules load,
	// accessed atomically
	lastLoad int64
}

// getLastLoadTime atomically reads the timestamp of the last rules load.
func (r *EventRulesContainer) getLastLoadTime() int64 {
	return atomic.LoadInt64(&r.lastLoad)
}

// setLastLoadTime atomically records the timestamp of the last rules load.
func (r *EventRulesContainer) setLastLoadTime(modTime int64) {
	atomic.StoreInt64(&r.lastLoad, modTime)
}
// RemoveRule deletes the rule with the specified name, taking the write lock.
func (r *EventRulesContainer) RemoveRule(name string) {
	r.Lock()
	defer r.Unlock()

	r.removeRuleInternal(name)
	eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d",
		len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))
}
// removeRuleInternal removes the named rule from whichever trigger list
// contains it, using swap-with-last removal (list order is not preserved).
// For scheduled rules the associated cron entries are unregistered from the
// scheduler too. The caller must hold the write lock.
func (r *EventRulesContainer) removeRuleInternal(name string) {
	for idx := range r.FsEvents {
		if r.FsEvents[idx].Name == name {
			lastIdx := len(r.FsEvents) - 1
			r.FsEvents[idx] = r.FsEvents[lastIdx]
			r.FsEvents = r.FsEvents[:lastIdx]
			eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name)
			return
		}
	}
	for idx := range r.ProviderEvents {
		if r.ProviderEvents[idx].Name == name {
			lastIdx := len(r.ProviderEvents) - 1
			r.ProviderEvents[idx] = r.ProviderEvents[lastIdx]
			r.ProviderEvents = r.ProviderEvents[:lastIdx]
			eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name)
			return
		}
	}
	for idx := range r.Schedules {
		if r.Schedules[idx].Name == name {
			// unregister every cron entry created for this rule
			if schedules, ok := r.schedulesMapping[name]; ok {
				for _, entryID := range schedules {
					eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name)
					scheduler.Remove(entryID)
				}
				delete(r.schedulesMapping, name)
			}
			lastIdx := len(r.Schedules) - 1
			r.Schedules[idx] = r.Schedules[lastIdx]
			r.Schedules = r.Schedules[:lastIdx]
			eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name)
			return
		}
	}
}
// addUpdateRuleInternal replaces any existing rule with the same name and
// registers the new definition with the list matching its trigger.
// Soft-deleted rules are not re-added; if the soft delete is older than
// 30 minutes the rule is also purged from the provider in the background,
// best-effort (the error is deliberately ignored). The caller must hold the
// write lock.
func (r *EventRulesContainer) addUpdateRuleInternal(rule EventRule) {
	r.removeRuleInternal(rule.Name)
	if rule.DeletedAt > 0 {
		deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt)
		if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
			eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt)
			go provider.deleteEventRule(rule, false) //nolint:errcheck
		}
		return
	}
	switch rule.Trigger {
	case EventTriggerFsEvent:
		r.FsEvents = append(r.FsEvents, rule)
		eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name)
	case EventTriggerProviderEvent:
		r.ProviderEvents = append(r.ProviderEvents, rule)
		eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name)
	case EventTriggerSchedule:
		r.Schedules = append(r.Schedules, rule)
		eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name)
		// register one cron job per configured schedule; failures are logged
		// but do not prevent the other schedules from being registered
		for _, schedule := range rule.Conditions.Schedules {
			cronSpec := schedule.getCronSpec()
			job := &cronJob{
				ruleName: ConvertName(rule.Name),
			}
			entryID, err := scheduler.AddJob(cronSpec, job)
			if err != nil {
				eventManagerLog(logger.LevelError, "unable to add scheduled rule %q: %v", rule.Name, err)
			} else {
				r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID)
				eventManagerLog(logger.LevelDebug, "scheduled rule %q added, id: %d, active scheduling rules: %d",
					rule.Name, entryID, len(r.schedulesMapping))
			}
		}
	default:
		eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger)
	}
}
// loadRules fetches the rules updated after the last load time and merges
// them into the container. The new load timestamp is captured before the
// provider query, so updates racing with the query are picked up again on
// the next load rather than being missed.
func (r *EventRulesContainer) loadRules() {
	eventManagerLog(logger.LevelDebug, "loading updated rules")
	modTime := util.GetTimeAsMsSinceEpoch(time.Now())
	rules, err := provider.getRecentlyUpdatedRules(r.getLastLoadTime())
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to load event rules: %v", err)
		return
	}
	eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules))

	if len(rules) > 0 {
		r.Lock()
		defer r.Unlock()

		for _, rule := range rules {
			r.addUpdateRuleInternal(rule)
		}
	}
	eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d",
		len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules))

	r.setLastLoadTime(modTime)
}
// HasFsRules returns true if there are any rules for filesystem event triggers.
func (r *EventRulesContainer) HasFsRules() bool {
	r.RLock()
	defer r.RUnlock()

	return len(r.FsEvents) > 0
}
// hasProviderEvents returns true if there are any rules for provider event triggers.
func (r *EventRulesContainer) hasProviderEvents() bool {
	r.RLock()
	defer r.RUnlock()

	return len(r.ProviderEvents) > 0
}
// HandleFsEvent executes the rules actions defined for the specified event.
// Matching rules whose actions are all asynchronous run in a background
// goroutine; rules containing at least one sync action are executed inline
// and their (last) error is returned to the caller.
func (r *EventRulesContainer) HandleFsEvent(params EventParams) error {
	r.RLock()

	var rulesWithSyncActions, rulesAsync []EventRule
	for _, rule := range r.FsEvents {
		if rule.Conditions.FsEventMatch(params) {
			hasSyncActions := false
			for _, action := range rule.Actions {
				if action.Options.ExecuteSync {
					hasSyncActions = true
					break
				}
			}
			if hasSyncActions {
				rulesWithSyncActions = append(rulesWithSyncActions, rule)
			} else {
				rulesAsync = append(rulesAsync, rule)
			}
		}
	}

	// release the lock before executing any action
	r.RUnlock()

	if len(rulesAsync) > 0 {
		go executeAsyncActions(rulesAsync, params)
	}

	if len(rulesWithSyncActions) > 0 {
		return executeSyncActions(rulesWithSyncActions, params)
	}
	return nil
}
// handleProviderEvent starts, in a background goroutine, the actions of all
// provider-event rules matching the given params.
func (r *EventRulesContainer) handleProviderEvent(params EventParams) {
	r.RLock()
	defer r.RUnlock()

	var rules []EventRule
	for _, rule := range r.ProviderEvents {
		if rule.Conditions.ProviderEventMatch(params) {
			rules = append(rules, rule)
		}
	}

	go executeAsyncActions(rules, params)
}
// EventParams defines the supported event parameters, used both for rule
// matching and for "{{...}}" placeholder expansion in actions.
type EventParams struct {
	Name              string // executor username or object name
	Event             string // event/operation name, e.g. "upload", "update"
	Status            int
	VirtualPath       string
	FsPath            string
	VirtualTargetPath string
	FsTargetPath      string
	ObjectName        string
	ObjectType        string // provider object type, e.g. "folder"
	FileSize          int64
	Protocol          string
	IP                string
	Timestamp         int64 // unix time in nanoseconds
	Object            plugin.Renderer // provider object, rendered as JSON for {{ObjectData}}
}
// getStringReplacements returns the placeholder -> value pairs, in the order
// expected by strings.NewReplacer. ObjectData is rendered only when
// addObjectData is true; if rendering fails the placeholder is silently
// omitted and will appear verbatim in the expanded text.
func (p *EventParams) getStringReplacements(addObjectData bool) []string {
	replacements := []string{
		"{{Name}}", p.Name,
		"{{Event}}", p.Event,
		"{{Status}}", fmt.Sprintf("%d", p.Status),
		"{{VirtualPath}}", p.VirtualPath,
		"{{FsPath}}", p.FsPath,
		"{{VirtualTargetPath}}", p.VirtualTargetPath,
		"{{FsTargetPath}}", p.FsTargetPath,
		"{{ObjectName}}", p.ObjectName,
		"{{ObjectType}}", p.ObjectType,
		"{{FileSize}}", fmt.Sprintf("%d", p.FileSize),
		"{{Protocol}}", p.Protocol,
		"{{IP}}", p.IP,
		"{{Timestamp}}", fmt.Sprintf("%d", p.Timestamp),
	}
	if addObjectData {
		// reload flag is false for deletes — presumably because the object no
		// longer exists in the provider; confirm against RenderAsJSON
		data, err := p.Object.RenderAsJSON(p.Event != operationDelete)
		if err == nil {
			replacements = append(replacements, "{{ObjectData}}", string(data))
		}
	}
	return replacements
}
func replaceWithReplacer(input string, replacer *strings.Replacer) string {
if !strings.Contains(input, "{{") {
return input
}
return replacer.Replace(input)
}
// checkConditionPatterns reports whether name satisfies the patterns: an
// empty pattern list always matches, otherwise at least one pattern must
// match.
func checkConditionPatterns(name string, patterns []ConditionPattern) bool {
	if len(patterns) == 0 {
		return true
	}
	for idx := range patterns {
		if patterns[idx].match(name) {
			return true
		}
	}
	return false
}
// executeSyncActions runs, for each rule, the non-failure actions marked
// ExecuteSync, in order, and returns the last error encountered (if any).
// After each rule's sync actions, its async actions (including failure
// actions) are started in a separate goroutine, receiving the names of the
// sync actions that failed.
func executeSyncActions(rules []EventRule, params EventParams) error {
	var errRes error

	for _, rule := range rules {
		var failedActions []string
		for _, action := range rule.Actions {
			if !action.Options.IsFailureAction && action.Options.ExecuteSync {
				startTime := time.Now()
				if err := action.execute(params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					failedActions = append(failedActions, action.Name)
					// we return the last error, it is ok for now
					errRes = err
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
		// execute async actions if any, including failure actions
		go executeRuleAsyncActions(rule, params, failedActions)
	}

	return errRes
}
// executeAsyncActions runs the asynchronous actions of each rule, starting
// with no previously failed actions recorded.
func executeAsyncActions(rules []EventRule, params EventParams) {
	for idx := range rules {
		executeRuleAsyncActions(rules[idx], params, nil)
	}
}
// executeRuleAsyncActions executes the asynchronous, non-failure actions of
// the given rule in order. Action names that fail are accumulated on top of
// failedActions (which may already contain failures from the sync phase);
// if any action failed, the rule's failure actions are executed once, after
// the main loop.
//
// Fix: the failure-action block was previously nested inside the main action
// loop, so (a) failure actions could run once per remaining iteration after
// the first failure, and (b) they were skipped entirely when StopOnFailure
// broke out of the loop. It is now executed exactly once after the loop.
func executeRuleAsyncActions(rule EventRule, params EventParams, failedActions []string) {
	for _, action := range rule.Actions {
		if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
			startTime := time.Now()
			if err := action.execute(params, rule.Conditions.Options); err != nil {
				eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
					action.Name, rule.Name, time.Since(startTime), err)
				failedActions = append(failedActions, action.Name)
				if action.Options.StopOnFailure {
					break
				}
			} else {
				eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
					action.Name, rule.Name, time.Since(startTime))
			}
		}
	}
	if len(failedActions) > 0 {
		// execute failure actions
		for _, action := range rule.Actions {
			if action.Options.IsFailureAction {
				startTime := time.Now()
				if err := action.execute(params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
	}
}
// cronJob is the job executed by the scheduler for a time-based event rule.
// It holds only the rule name; the rule itself is reloaded from the provider
// on each Run so schedule changes are picked up.
type cronJob struct {
	// name of the event rule to execute
	ruleName string
}
// getTask returns the persisted Task used to coordinate executions of the
// given rule across SFTPGo instances sharing the same data provider.
// An empty Task (Name == "") is returned when the rule does not need to be
// guarded from concurrent execution. A missing task record is created on
// first use with a zero timestamp/version.
//
// Fix: the lookup-result branch previously logged a "unable to get task"
// warning even when getTaskByName succeeded (err == nil), producing a
// spurious "err: <nil>" warning on every run; now it warns only on error.
func (j *cronJob) getTask(rule EventRule) (Task, error) {
	if rule.guardFromConcurrentExecution() {
		task, err := provider.getTaskByName(rule.Name)
		// NOTE(review): a plain type assertion misses wrapped errors;
		// consider errors.As if the provider may wrap RecordNotFoundError.
		if _, ok := err.(*util.RecordNotFoundError); ok {
			// no task for this rule yet: create one
			eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
			task = Task{
				Name:     rule.Name,
				UpdateAt: 0,
				Version:  0,
			}
			err = provider.addTask(rule.Name)
			if err != nil {
				eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
				return task, err
			}
		} else if err != nil {
			eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err)
		}
		return task, err
	}
	return Task{}, nil
}
// Run executes the scheduled rule. It reloads the rule from the provider,
// and, when the rule must be guarded from concurrent execution across
// instances, uses the shared task record as a lock: the execution is skipped
// if the task timestamp is recent (another instance is presumably running
// the rule) or if the optimistic version update fails, and a background
// goroutine keeps the timestamp fresh while the actions run.
func (j *cronJob) Run() {
	eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName)
	rule, err := provider.eventRuleExists(j.ruleName)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName)
		return
	}
	task, err := j.getTask(rule)
	if err != nil {
		return
	}
	if task.Name != "" {
		// a task record exists: coordinate with other instances
		updateInterval := 5 * time.Minute
		updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt)
		// NOTE(review): the "+ 1" adds one nanosecond to the 10 minute
		// window — it looks like it may have been meant as a larger unit;
		// confirm intent.
		if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) {
			// the timestamp was refreshed recently, the rule is probably
			// still running on another instance: skip this execution
			eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt)
			return
		}
		err = provider.updateTask(rule.Name, task.Version)
		if err != nil {
			// optimistic-locking update failed (version changed or provider
			// error): another instance won the race, skip this execution
			eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v",
				rule.Name, err)
			return
		}
		// heartbeat: periodically refresh the task timestamp so other
		// instances keep seeing this rule as running
		ticker := time.NewTicker(updateInterval)
		done := make(chan bool)
		go func(taskName string) {
			eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName)
			for {
				select {
				case <-done:
					eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName)
					return
				case <-ticker.C:
					err := provider.updateTaskTimestamp(taskName)
					eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err)
				}
			}
		}(task.Name)
		executeRuleAsyncActions(rule, EventParams{}, nil)
		// stop the heartbeat worker, then the ticker
		done <- true
		ticker.Stop()
	} else {
		// no concurrency guard required: just run the actions
		executeRuleAsyncActions(rule, EventParams{}, nil)
	}
	eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName)
}
// cloneKeyValues returns an independent copy of the given key/value pairs.
func cloneKeyValues(keyVals []KeyValue) []KeyValue {
	cloned := make([]KeyValue, len(keyVals))
	for idx, kv := range keyVals {
		cloned[idx] = KeyValue{
			Key:   kv.Key,
			Value: kv.Value,
		}
	}
	return cloned
}
// cloneConditionPatterns returns an independent copy of the given patterns.
func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern {
	cloned := make([]ConditionPattern, len(patterns))
	for idx, p := range patterns {
		cloned[idx] = ConditionPattern{
			Pattern:      p.Pattern,
			InverseMatch: p.InverseMatch,
		}
	}
	return cloned
}
// eventManagerLog logs a formatted message at the given level using
// "eventmanager" as sender and an empty connection ID.
func eventManagerLog(level logger.LogLevel, format string, v ...any) {
	logger.Log(level, "eventmanager", "", format, v...)
}

View file

@ -1,141 +0,0 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"sync"
"time"
"github.com/drakkan/sftpgo/v2/internal/util"
)
var (
	// QuotaScans is the list of active quota scans, for both user home
	// directories and virtual folders
	QuotaScans ActiveScans
)
// ActiveQuotaScan defines an active quota scan for a user home dir
type ActiveQuotaScan struct {
	// Username to which the quota scan refers
	Username string `json:"username"`
	// quota scan start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
}
// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder
type ActiveVirtualFolderQuotaScan struct {
	// folder name to which the quota scan refers
	Name string `json:"name"`
	// quota scan start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
}
// ActiveScans holds the active quota scans for users and virtual folders.
// The embedded RWMutex guards both slices, so a value of this type must not
// be copied after first use; access it through pointer methods.
type ActiveScans struct {
	sync.RWMutex
	UserScans   []ActiveQuotaScan
	FolderScans []ActiveVirtualFolderQuotaScan
}
// GetUsersQuotaScans returns a snapshot copy of the active quota scans for
// users home directories.
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
	s.RLock()
	defer s.RUnlock()

	result := make([]ActiveQuotaScan, len(s.UserScans))
	copy(result, s.UserScans)
	return result
}
// AddUserQuotaScan adds a user to the ones with active quota scans.
// It returns false, without adding anything, if a quota scan for this user
// is already running.
func (s *ActiveScans) AddUserQuotaScan(username string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.UserScans {
		if s.UserScans[idx].Username == username {
			// a scan for this user is already in progress
			return false
		}
	}
	s.UserScans = append(s.UserScans, ActiveQuotaScan{
		Username:  username,
		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	return true
}
// RemoveUserQuotaScan removes a user from the ones with active quota scans.
// It returns false if the user has no active quota scan.
func (s *ActiveScans) RemoveUserQuotaScan(username string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.UserScans {
		if s.UserScans[idx].Username != username {
			continue
		}
		// swap-delete: order of the remaining scans is not significant
		last := len(s.UserScans) - 1
		s.UserScans[idx] = s.UserScans[last]
		s.UserScans = s.UserScans[:last]
		return true
	}
	return false
}
// GetVFoldersQuotaScans returns a snapshot copy of the active quota scans
// for virtual folders.
func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan {
	s.RLock()
	defer s.RUnlock()

	result := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans))
	copy(result, s.FolderScans)
	return result
}
// AddVFolderQuotaScan adds a virtual folder to the ones with active quota
// scans. It returns false, without adding anything, if a quota scan for this
// folder is already running.
func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.FolderScans {
		if s.FolderScans[idx].Name == folderName {
			// a scan for this folder is already in progress
			return false
		}
	}
	s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{
		Name:      folderName,
		StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	return true
}
// RemoveVFolderQuotaScan removes a folder from the ones with active quota
// scans. It returns false if the folder has no active quota scan.
func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
	s.Lock()
	defer s.Unlock()

	for idx := range s.FolderScans {
		if s.FolderScans[idx].Name != folderName {
			continue
		}
		// swap-delete: order of the remaining scans is not significant
		last := len(s.FolderScans) - 1
		s.FolderScans[idx] = s.FolderScans[last]
		s.FolderScans = s.FolderScans[:last]
		return true
	}
	return false
}

View file

@ -54,7 +54,9 @@ func startScheduler() error {
if err != nil {
return err
}
EventManager.loadRules()
if fnReloadRules != nil {
fnReloadRules()
}
scheduler.Start()
return nil
}
@ -77,7 +79,7 @@ func checkDataprovider() {
}
func checkCacheUpdates() {
providerLog(logger.LevelDebug, "start caches check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate)
if err != nil {
@ -101,8 +103,7 @@ func checkCacheUpdates() {
}
lastUserCacheUpdate = checkTime
EventManager.loadRules()
providerLog(logger.LevelDebug, "end caches check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
}
func setLastUserUpdate() {

View file

@ -27,6 +27,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
@ -265,7 +266,7 @@ func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, sca
return fmt.Errorf("unable to restore folder %#v: %w", folder.Name, err)
}
if scanQuota >= 1 {
if dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
if common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
logger.Debug(logSender, "", "starting quota scan for restored folder: %#v", folder.Name)
go doFolderQuotaScan(folder) //nolint:errcheck
}
@ -453,7 +454,7 @@ func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota i
return fmt.Errorf("unable to restore user %#v: %w", user.Username, err)
}
if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
if dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
if common.QuotaScans.AddUserQuotaScan(user.Username) {
logger.Debug(logSender, "", "starting quota scan for restored user: %#v", user.Username)
go doUserQuotaScan(user) //nolint:errcheck
}

View file

@ -21,6 +21,7 @@ import (
"github.com/go-chi/render"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/vfs"
@ -43,12 +44,12 @@ type transferQuotaUsage struct {
func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
render.JSON(w, r, dataprovider.QuotaScans.GetUsersQuotaScans())
render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans())
}
func getFoldersQuotaScans(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
render.JSON(w, r, dataprovider.QuotaScans.GetVFoldersQuotaScans())
render.JSON(w, r, common.QuotaScans.GetVFoldersQuotaScans())
}
func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) {
@ -141,11 +142,11 @@ func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username str
"", http.StatusBadRequest)
return
}
if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
if !common.QuotaScans.AddUserQuotaScan(user.Username) {
sendAPIResponse(w, r, err, "A quota scan is in progress for this user", http.StatusConflict)
return
}
defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -170,11 +171,11 @@ func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name strin
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
sendAPIResponse(w, r, err, "A quota scan is in progress for this folder", http.StatusConflict)
return
}
defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
err = dataprovider.UpdateVirtualFolderQuota(&folder, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
@ -193,7 +194,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
if !dataprovider.QuotaScans.AddUserQuotaScan(user.Username) {
if !common.QuotaScans.AddUserQuotaScan(user.Username) {
sendAPIResponse(w, r, nil, fmt.Sprintf("Another scan is already in progress for user %#v", username),
http.StatusConflict)
return
@ -212,7 +213,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
if !dataprovider.QuotaScans.AddVFolderQuotaScan(folder.Name) {
if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) {
sendAPIResponse(w, r, err, fmt.Sprintf("Another scan is already in progress for folder %#v", name),
http.StatusConflict)
return
@ -222,7 +223,7 @@ func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string)
}
func doUserQuotaScan(user dataprovider.User) error {
defer dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username)
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
numFiles, size, err := user.ScanQuota()
if err != nil {
logger.Warn(logSender, "", "error scanning user quota %#v: %v", user.Username, err)
@ -234,7 +235,7 @@ func doUserQuotaScan(user dataprovider.User) error {
}
func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error {
defer dataprovider.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
f := vfs.VirtualFolder{
BaseVirtualFolder: folder,
VirtualPath: "/",

View file

@ -1951,6 +1951,8 @@ func TestHTTPUserAuthEmptyPassword(t *testing.T) {
c.CloseIdleConnections()
assert.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
err = resp.Body.Close()
assert.NoError(t, err)
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, "")
if assert.Error(t, err) {
@ -1986,6 +1988,8 @@ func TestHTTPAnonymousUser(t *testing.T) {
c.CloseIdleConnections()
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
err = resp.Body.Close()
assert.NoError(t, err)
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
if assert.Error(t, err) {
@ -9347,12 +9351,12 @@ func TestUpdateUserQuotaUsageMock(t *testing.T) {
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusBadRequest, rr)
assert.True(t, dataprovider.QuotaScans.AddUserQuotaScan(user.Username))
assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username))
req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON))
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusConflict, rr)
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username))
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))
req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
setBearerForReq(req, token)
rr = executeRequest(req)
@ -9624,12 +9628,12 @@ func TestStartQuotaScanMock(t *testing.T) {
assert.NoError(t, err)
}
// simulate a duplicate quota scan
dataprovider.QuotaScans.AddUserQuotaScan(user.Username)
common.QuotaScans.AddUserQuotaScan(user.Username)
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusConflict, rr)
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(user.Username))
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username))
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil)
setBearerForReq(req, token)
@ -9743,13 +9747,13 @@ func TestUpdateFolderQuotaUsageMock(t *testing.T) {
rr = executeRequest(req)
checkResponseCode(t, http.StatusBadRequest, rr)
assert.True(t, dataprovider.QuotaScans.AddVFolderQuotaScan(folderName))
assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName))
req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"),
bytes.NewBuffer(folderAsJSON))
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusConflict, rr)
assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName))
assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))
req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
setBearerForReq(req, token)
@ -9778,12 +9782,12 @@ func TestStartFolderQuotaScanMock(t *testing.T) {
assert.NoError(t, err)
}
// simulate a duplicate quota scan
dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
common.QuotaScans.AddVFolderQuotaScan(folderName)
req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil)
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusConflict, rr)
assert.True(t, dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName))
assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName))
// and now a real quota scan
_, err = os.Stat(mappedPath)
if err != nil && errors.Is(err, fs.ErrNotExist) {
@ -20392,7 +20396,7 @@ func startOIDCMockServer() {
func waitForUsersQuotaScan(t *testing.T, token string) {
for {
var scans []dataprovider.ActiveQuotaScan
var scans []common.ActiveQuotaScan
req, _ := http.NewRequest(http.MethodGet, quotaScanPath, nil)
setBearerForReq(req, token)
rr := executeRequest(req)
@ -20410,7 +20414,7 @@ func waitForUsersQuotaScan(t *testing.T, token string) {
}
func waitForFoldersQuotaScanPath(t *testing.T, token string) {
var scans []dataprovider.ActiveVirtualFolderQuotaScan
var scans []common.ActiveVirtualFolderQuotaScan
for {
req, _ := http.NewRequest(http.MethodGet, quotaScanVFolderPath, nil)
setBearerForReq(req, token)

View file

@ -1421,7 +1421,7 @@ func TestQuotaScanInvalidFs(t *testing.T) {
Provider: sdk.S3FilesystemProvider,
},
}
dataprovider.QuotaScans.AddUserQuotaScan(user.Username)
common.QuotaScans.AddUserQuotaScan(user.Username)
err := doUserQuotaScan(user)
assert.Error(t, err)
}

View file

@ -833,8 +833,8 @@ func GetEventRules(limit, offset int64, expectedStatusCode int) ([]dataprovider.
}
// GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode.
func GetQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveQuotaScan, []byte, error) {
var quotaScans []dataprovider.ActiveQuotaScan
func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) {
var quotaScans []common.ActiveQuotaScan
var body []byte
resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken())
if err != nil {
@ -1077,8 +1077,8 @@ func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVi
}
// GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode.
func GetFoldersQuotaScans(expectedStatusCode int) ([]dataprovider.ActiveVirtualFolderQuotaScan, []byte, error) {
var quotaScans []dataprovider.ActiveVirtualFolderQuotaScan
func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) {
var quotaScans []common.ActiveVirtualFolderQuotaScan
var body []byte
resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken())
if err != nil {

View file

@ -155,7 +155,7 @@ func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir st
}
func TestRemoveNonexistentQuotaScan(t *testing.T) {
assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan("username"))
assert.False(t, common.QuotaScans.RemoveUserQuotaScan("username"))
}
func TestGetOSOpenFlags(t *testing.T) {

View file

@ -1213,9 +1213,6 @@ func TestRealPath(t *testing.T) {
for _, user := range []dataprovider.User{localUser, sftpUser} {
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
p, err := client.RealPath("../..")
assert.NoError(t, err)
assert.Equal(t, "/", p)
@ -1247,6 +1244,9 @@ func TestRealPath(t *testing.T) {
assert.NoError(t, err)
_, err = client.RealPath(path.Join(subdir, "temp"))
assert.ErrorIs(t, err, os.ErrPermission)
conn.Close()
client.Close()
err = os.Remove(filepath.Join(localUser.GetHomeDir(), subdir, "temp"))
assert.NoError(t, err)
if user.Username == localUser.Username {
@ -4590,14 +4590,14 @@ func TestQuotaScan(t *testing.T) {
}
func TestMultipleQuotaScans(t *testing.T) {
res := dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername)
res := common.QuotaScans.AddUserQuotaScan(defaultUsername)
assert.True(t, res)
res = dataprovider.QuotaScans.AddUserQuotaScan(defaultUsername)
res = common.QuotaScans.AddUserQuotaScan(defaultUsername)
assert.False(t, res, "add quota must fail if another scan is already active")
assert.True(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername))
activeScans := dataprovider.QuotaScans.GetUsersQuotaScans()
assert.True(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername))
activeScans := common.QuotaScans.GetUsersQuotaScans()
assert.Equal(t, 0, len(activeScans))
assert.False(t, dataprovider.QuotaScans.RemoveUserQuotaScan(defaultUsername))
assert.False(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername))
}
func TestQuotaLimits(t *testing.T) {
@ -6949,15 +6949,15 @@ func TestVirtualFolderQuotaScan(t *testing.T) {
func TestVFolderMultipleQuotaScan(t *testing.T) {
folderName := "folder_name"
res := dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
res := common.QuotaScans.AddVFolderQuotaScan(folderName)
assert.True(t, res)
res = dataprovider.QuotaScans.AddVFolderQuotaScan(folderName)
res = common.QuotaScans.AddVFolderQuotaScan(folderName)
assert.False(t, res)
res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)
res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
assert.True(t, res)
activeScans := dataprovider.QuotaScans.GetVFoldersQuotaScans()
activeScans := common.QuotaScans.GetVFoldersQuotaScans()
assert.Len(t, activeScans, 0)
res = dataprovider.QuotaScans.RemoveVFolderQuotaScan(folderName)
res = common.QuotaScans.RemoveVFolderQuotaScan(folderName)
assert.False(t, res)
}

View file

@ -607,3 +607,10 @@ func GetTLSVersion(val int) uint16 {
func IsEmailValid(email string) bool {
return emailRegex.MatchString(email)
}
// PanicOnError calls panic if err is not nil
func PanicOnError(err error) {
if err != nil {
panic(fmt.Errorf("unexpected error: %w", err))
}
}

View file

@ -156,7 +156,7 @@ func (fs *CryptFs) Create(name string, flag int) (File, *PipeWriter, func(), err
if flag == 0 {
f, err = os.Create(name)
} else {
f, err = os.OpenFile(name, flag, os.ModePerm)
f, err = os.OpenFile(name, flag, 0666)
}
if err != nil {
return nil, nil, nil, err