From 1df1225eed90d4030b94c4a72bf18c3328dee7ec Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sun, 30 Jan 2022 11:42:36 +0100 Subject: [PATCH] add support for data transfer bandwidth limits with total limit or separate settings for uploads and downloads and overrides based on the client's IP address. Limits can be reset using the REST API Signed-off-by: Nicola Murino --- README.md | 3 +- cmd/revertprovider.go | 2 +- cmd/startsubsys.go | 4 +- common/common.go | 47 ++-- common/common_test.go | 22 +- common/connection.go | 107 ++++++-- common/connection_test.go | 7 + common/protocol_test.go | 111 +++++++- common/transfer.go | 68 ++++- common/transfer_test.go | 140 +++++++++- common/transferschecker.go | 330 ++++++++++++++++------ common/transferschecker_test.go | 325 ++++++++++++++++++++-- dataprovider/bolt.go | 73 ++++- dataprovider/dataprovider.go | 187 +++++++++++-- dataprovider/memory.go | 57 +++- dataprovider/mysql.go | 81 +++++- dataprovider/pgsql.go | 83 +++++- dataprovider/quotaupdater.go | 72 ++++- dataprovider/sqlcommon.go | 163 +++++++++-- dataprovider/sqlite.go | 79 +++++- dataprovider/sqlqueries.go | 74 ++++- dataprovider/user.go | 116 ++++++-- docs/full-configuration.md | 2 +- docs/howto/getting-started.md | 2 +- ftpd/ftpd_test.go | 152 ++++++++--- ftpd/handler.go | 73 +++-- ftpd/internal_test.go | 6 +- ftpd/transfer.go | 7 +- go.mod | 40 +-- go.sum | 78 +++--- httpd/api_http_user.go | 7 + httpd/api_quota.go | 46 +++- httpd/api_shares.go | 15 + httpd/api_utils.go | 8 +- httpd/file.go | 7 +- httpd/handler.go | 20 +- httpd/httpd_test.go | 471 +++++++++++++++++++++++++++++++- httpd/internal_test.go | 2 +- httpd/server.go | 8 +- httpd/webadmin.go | 106 +++++-- httpdtest/httpdtest.go | 76 +++++- openapi/openapi.yaml | 124 +++++++-- service/service.go | 6 +- sftpd/handler.go | 25 +- sftpd/internal_test.go | 42 ++- sftpd/scp.go | 18 +- sftpd/server.go | 2 - sftpd/sftpd_test.go | 121 +++++++- sftpd/ssh_cmd.go | 43 +-- sftpd/transfer.go | 19 +- templates/webadmin/user.html | 197 ++++++++++++- webdavd/file.go | 16 +- webdavd/handler.go | 21 +- webdavd/internal_test.go | 26 +- webdavd/webdavd_test.go | 279 ++++++++++++++----- 55 files changed, 3573 insertions(+), 643 deletions(-) diff --git a/README.md b/README.md index 3129ec99..e1164a27 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ Several storage backends are supported: local filesystem, encrypted local filesy - [Data At Rest Encryption](./docs/dare.md). - Dynamic user modification before login via external programs/HTTP API. - Quota support: accounts can have individual quota expressed as max total size and/or max number of files. -- Bandwidth throttling, with distinct settings for upload and download and overrides based on the client IP address. +- Bandwidth throttling, with separate settings for upload and download and overrides based on the client's IP address. +- Data transfer bandwidth limits, with total limit or separate settings for uploads and downloads and overrides based on the client's IP address. Limits can be reset using the REST API. - Per-protocol [rate limiting](./docs/rate-limiting.md) is supported and can be optionally connected to the built-in defender to automatically block hosts that repeatedly exceed the configured limit. - Per user maximum concurrent sessions. - Per user and global IP filters: login can be restricted to specific ranges of IP addresses or to a specific IP address. diff --git a/cmd/revertprovider.go b/cmd/revertprovider.go index 64d82e4c..5d0a9ca9 100644 --- a/cmd/revertprovider.go +++ b/cmd/revertprovider.go @@ -57,7 +57,7 @@ Please take a look at the usage below to customize the options.`, func init() { addConfigFlags(revertProviderCmd) - revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 15, `15 means the version supported in v2.2.1`) + revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 15, `15 means the version supported in v2.2.x`) revertProviderCmd.MarkFlagRequired("to-version") //nolint:errcheck rootCmd.AddCommand(revertProviderCmd) diff --git a/cmd/startsubsys.go b/cmd/startsubsys.go index d45f0242..7971f35a 100644 --- a/cmd/startsubsys.go +++ b/cmd/startsubsys.go @@ -64,11 +64,12 @@ Command-line flags should be specified in the Subsystem declaration. logger.Error(logSender, connectionID, "unable to load configuration: %v", err) os.Exit(1) } + dataProviderConf := config.GetProviderConf() commonConfig := config.GetCommonConfig() // idle connection are managed externally commonConfig.IdleTimeout = 0 config.SetCommonConfig(commonConfig) - if err := common.Initialize(config.GetCommonConfig()); err != nil { + if err := common.Initialize(config.GetCommonConfig(), dataProviderConf.GetShared()); err != nil { logger.Error(logSender, connectionID, "%v", err) os.Exit(1) } @@ -93,7 +94,6 @@ Command-line flags should be specified in the Subsystem declaration. logger.Error(logSender, connectionID, "unable to initialize SMTP configuration: %v", err) os.Exit(1) } - dataProviderConf := config.GetProviderConf() if dataProviderConf.Driver == dataprovider.SQLiteDataProviderName || dataProviderConf.Driver == dataprovider.BoltDataProviderName { logger.Debug(logSender, connectionID, "data provider %#v not supported in subsystem mode, using %#v provider", dataProviderConf.Driver, dataprovider.MemoryDataProviderName) diff --git a/common/common.go b/common/common.go index dffadb0c..2ba2cb94 100644 --- a/common/common.go +++ b/common/common.go @@ -105,6 +105,7 @@ var ( ErrOpUnsupported = errors.New("operation unsupported") ErrGenericFailure = errors.New("failure") ErrQuotaExceeded = errors.New("denying write due to space limit") + ErrReadQuotaExceeded = errors.New("denying read due to quota limit") ErrSkipPermissionsCheck = errors.New("permission check skipped") ErrConnectionDenied = errors.New("you are not allowed to connect") ErrNoBinding = errors.New("no binding configured") @@ -134,7 +135,7 @@ var ( ) // Initialize sets the common configuration -func Initialize(c Configuration) error { +func Initialize(c Configuration, isShared int) error { Config = c Config.idleLoginTimeout = 2 * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute @@ -177,7 +178,7 @@ func Initialize(c Configuration) error { } vfs.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath) - transfersChecker = getTransfersChecker() + transfersChecker = getTransfersChecker(isShared) return nil } @@ -314,7 +315,7 @@ type ActiveTransfer interface { GetRealFsPath(fsPath string) string SetTimes(fsPath string, atime time.Time, mtime time.Time) bool GetTruncatedSize() int64 - GetMaxAllowedSize() int64 + HasSizeLimit() bool } // ActiveConnection defines the interface for the current active connections @@ -349,14 +350,14 @@ type StatAttributes struct { // ConnectionTransfer defines the trasfer details to expose type ConnectionTransfer struct { - ID int64 `json:"-"` - OperationType string `json:"operation_type"` - StartTime int64 `json:"start_time"` - Size int64 `json:"size"` - VirtualPath string `json:"path"` - MaxAllowedSize int64 `json:"-"` - ULSize int64 `json:"-"` - DLSize int64 `json:"-"` + ID int64 `json:"-"` + OperationType string `json:"operation_type"` + StartTime int64 `json:"start_time"` + Size int64 `json:"size"` + VirtualPath string `json:"path"` + HasSizeLimit bool `json:"-"` + ULSize int64 `json:"-"` + DLSize int64 `json:"-"` } func (t *ConnectionTransfer) getConnectionTransferAsString() string { @@ -851,20 +852,24 @@ func (conns *ActiveConnections) checkTransfers() { atomic.StoreInt32(&conns.transfersCheckStatus, 1) defer atomic.StoreInt32(&conns.transfersCheckStatus, 0) - var wg sync.WaitGroup - - logger.Debug(logSender, "", "start concurrent transfers check") conns.RLock() + if len(conns.connections) < 2 { + conns.RUnlock() + return + } + var wg sync.WaitGroup + logger.Debug(logSender, "", "start concurrent transfers check") + // update the current size for transfers to monitors for _, c := range conns.connections { for _, t := range c.GetTransfers() { - if t.MaxAllowedSize > 0 { + if t.HasSizeLimit { wg.Add(1) go func(transfer ConnectionTransfer, connID string) { defer wg.Done() - transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID) + transfersChecker.UpdateTransferCurrentSizes(transfer.ULSize, transfer.DLSize, transfer.ID, connID) }(t, c.GetID()) } } @@ -887,9 +892,15 @@ func (conns *ActiveConnections) checkTransfers() { for _, c := range conns.connections { for _, overquotaTransfer := range overquotaTransfers { if c.GetID() == overquotaTransfer.ConnID { - logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ", + logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v", c.GetUsername(), overquotaTransfer.TransferID) - c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol())) + var err error + if overquotaTransfer.TransferType == TransferDownload { + err = getReadQuotaExceededError(c.GetProtocol()) + } else { + err = getQuotaExceededError(c.GetProtocol()) + } + c.SignalTransferClose(overquotaTransfer.TransferID, err) } } } diff --git a/common/common_test.go b/common/common_test.go index 694c2e44..e7ce72b3 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -159,20 +159,20 @@ func TestDefenderIntegration(t *testing.T) { EntriesSoftLimit: 100, EntriesHardLimit: 150, } - err = Initialize(Config) + err = Initialize(Config, 0) // ScoreInvalid cannot be greater than threshold assert.Error(t, err) Config.DefenderConfig.Driver = "unsupported" - err = Initialize(Config) + err = Initialize(Config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unsupported defender driver") } Config.DefenderConfig.Driver = DefenderDriverMemory - err = Initialize(Config) + err = Initialize(Config, 0) // ScoreInvalid cannot be greater than threshold assert.Error(t, err) Config.DefenderConfig.Threshold = 3 - err = Initialize(Config) + err = Initialize(Config, 0) assert.NoError(t, err) assert.Nil(t, ReloadDefender()) @@ -241,18 +241,18 @@ func TestRateLimitersIntegration(t *testing.T) { EntriesHardLimit: 150, }, } - err := Initialize(Config) + err := Initialize(Config, 0) assert.Error(t, err) Config.RateLimitersConfig[0].Period = 1000 Config.RateLimitersConfig[0].AllowList = []string{"1.1.1", "1.1.1.2"} - err = Initialize(Config) + err = Initialize(Config, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to parse rate limiter allow list") } Config.RateLimitersConfig[0].AllowList = []string{"172.16.24.7"} Config.RateLimitersConfig[1].AllowList = []string{"172.16.0.0/16"} - err = Initialize(Config) + err = Initialize(Config, 0) assert.NoError(t, err) assert.Len(t, rateLimiters, 4) @@ -355,7 +355,7 @@ func TestIdleConnections(t *testing.T) { configCopy := Config Config.IdleTimeout = 1 - err := Initialize(Config) + err := Initialize(Config, 0) assert.NoError(t, err) conn1, conn2 := net.Pipe() @@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) { fakeConn1 := &fakeConnection{ BaseConnection: c1, } - t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs) + t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) t1.BytesReceived = 123 - t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs) + t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) t2.BytesSent = 456 c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user) fakeConn2 := &fakeConnection{ @@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) { BaseConnection: c3, command: "PROPFIND", } - t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs) + t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) Connections.Add(fakeConn1) Connections.Add(fakeConn2) Connections.Add(fakeConn3) diff --git a/common/connection.go b/common/connection.go index c3b2cce4..e5b5635d 100644 --- a/common/connection.go +++ b/common/connection.go @@ -125,7 +125,7 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) { c.activeTransfers = append(c.activeTransfers, t) c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers)) - if t.GetMaxAllowedSize() > 0 { + if t.HasSizeLimit() { folderName := "" if t.GetType() == TransferUpload { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath())) @@ -141,6 +141,7 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) { ConnID: c.ID, Username: c.GetUsername(), FolderName: folderName, + IP: c.GetRemoteIP(), TruncatedSize: t.GetTruncatedSize(), CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), @@ -153,7 +154,7 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) { c.Lock() defer c.Unlock() - if t.GetMaxAllowedSize() > 0 { + if t.HasSizeLimit() { go transfersChecker.RemoveTransfer(t.GetID(), c.ID) } @@ -199,14 +200,14 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer { operationType = operationUpload } transfers = append(transfers, ConnectionTransfer{ - ID: t.GetID(), - OperationType: operationType, - StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), - Size: t.GetSize(), - VirtualPath: t.GetVirtualPath(), - MaxAllowedSize: t.GetMaxAllowedSize(), - ULSize: t.GetUploadedSize(), - DLSize: t.GetDownloadedSize(), + ID: t.GetID(), + OperationType: operationType, + StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), + Size: t.GetSize(), + VirtualPath: t.GetVirtualPath(), + HasSizeLimit: t.HasSizeLimit(), + ULSize: t.GetUploadedSize(), + DLSize: t.GetDownloadedSize(), }) } @@ -896,7 +897,7 @@ func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtual // rename between user root dir and a virtual folder included in user quota return true } - quotaResult := c.HasSpace(true, false, virtualTargetPath) + quotaResult, _ := c.HasSpace(true, false, virtualTargetPath) return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, fsSourcePath) } @@ -958,7 +959,9 @@ func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.Quota // GetMaxWriteSize returns the allowed size for an upload or an error // if no enough size is available for a resume/append -func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64, isUploadResumeSupported bool) (int64, error) { +func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64, + isUploadResumeSupported bool, +) (int64, error) { maxWriteSize := quotaResult.GetRemainingSize() if isResume { @@ -986,8 +989,49 @@ func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isRes return maxWriteSize, nil } +// GetTransferQuota returns the data transfers quota +func (c *BaseConnection) GetTransferQuota() dataprovider.TransferQuota { + result, _, _ := c.checkUserQuota() + return result +} + +func (c *BaseConnection) checkUserQuota() (dataprovider.TransferQuota, int, int64) { + clientIP := c.GetRemoteIP() + ul, dl, total := c.User.GetDataTransferLimits(clientIP) + result := dataprovider.TransferQuota{ + ULSize: ul, + DLSize: dl, + TotalSize: total, + AllowedULSize: 0, + AllowedDLSize: 0, + AllowedTotalSize: 0, + } + if !c.User.HasTransferQuotaRestrictions() { + return result, -1, -1 + } + usedFiles, usedSize, usedULSize, usedDLSize, err := dataprovider.GetUsedQuota(c.User.Username) + if err != nil { + c.Log(logger.LevelError, "error getting used quota for %#v: %v", c.User.Username, err) + result.AllowedTotalSize = -1 + return result, -1, -1 + } + if result.TotalSize > 0 { + result.AllowedTotalSize = result.TotalSize - (usedULSize + usedDLSize) + } + if result.ULSize > 0 { + result.AllowedULSize = result.ULSize - usedULSize + } + if result.DLSize > 0 { + result.AllowedDLSize = result.DLSize - usedDLSize + } + + return result, usedFiles, usedSize +} + // HasSpace checks user's quota usage -func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) vfs.QuotaCheckResult { +func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) (vfs.QuotaCheckResult, + dataprovider.TransferQuota, +) { result := vfs.QuotaCheckResult{ HasSpace: true, AllowedSize: 0, @@ -997,32 +1041,39 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) QuotaSize: 0, QuotaFiles: 0, } - if dataprovider.GetQuotaTracking() == 0 { - return result + return result, dataprovider.TransferQuota{} } + transferQuota, usedFiles, usedSize := c.checkUserQuota() + var err error var vfolder vfs.VirtualFolder vfolder, err = c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil && !vfolder.IsIncludedInUserQuota() { if vfolder.HasNoQuotaRestrictions(checkFiles) && !getUsage { - return result + return result, transferQuota } result.QuotaSize = vfolder.QuotaSize result.QuotaFiles = vfolder.QuotaFiles result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedVirtualFolderQuota(vfolder.Name) } else { if c.User.HasNoQuotaRestrictions(checkFiles) && !getUsage { - return result + return result, transferQuota } result.QuotaSize = c.User.QuotaSize result.QuotaFiles = c.User.QuotaFiles - result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedQuota(c.User.Username) + if usedSize == -1 { + result.UsedFiles, result.UsedSize, _, _, err = dataprovider.GetUsedQuota(c.User.Username) + } else { + err = nil + result.UsedFiles = usedFiles + result.UsedSize = usedSize + } } if err != nil { c.Log(logger.LevelError, "error getting used quota for %#v request path %#v: %v", c.User.Username, requestPath, err) result.HasSpace = false - return result + return result, transferQuota } result.AllowedFiles = result.QuotaFiles - result.UsedFiles result.AllowedSize = result.QuotaSize - result.UsedSize @@ -1031,9 +1082,9 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) c.Log(logger.LevelDebug, "quota exceed for user %#v, request path %#v, num files: %v/%v, size: %v/%v check files: %v", c.User.Username, requestPath, result.UsedFiles, result.QuotaFiles, result.UsedSize, result.QuotaSize, checkFiles) result.HasSpace = false - return result + return result, transferQuota } - return result + return result, transferQuota } // returns true if this is a rename on the same fs or local virtual folders @@ -1261,11 +1312,25 @@ func getQuotaExceededError(protocol string) error { } } +func getReadQuotaExceededError(protocol string) error { + switch protocol { + case ProtocolSFTP: + return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrReadQuotaExceeded.Error()) + default: + return ErrReadQuotaExceeded + } +} + // GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol func (c *BaseConnection) GetQuotaExceededError() error { return getQuotaExceededError(c.protocol) } +// GetReadQuotaExceededError returns an appropriate read quota limit exceeded error for the connection protocol +func (c *BaseConnection) GetReadQuotaExceededError() error { + return getReadQuotaExceededError(c.protocol) +} + // IsQuotaExceededError returns true if the given error is a quota exceeded error func (c *BaseConnection) IsQuotaExceededError(err error) bool { switch c.protocol { diff --git a/common/connection_test.go b/common/connection_test.go index 20be3ed6..b726347a 100644 --- a/common/connection_test.go +++ b/common/connection_test.go @@ -308,6 +308,13 @@ func TestErrorsMapping(t *testing.T) { } err = conn.GetQuotaExceededError() assert.True(t, conn.IsQuotaExceededError(err)) + err = conn.GetReadQuotaExceededError() + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) + } else { + assert.ErrorIs(t, err, ErrReadQuotaExceeded) + } err = conn.GetNotExistError() assert.True(t, conn.IsNotExistError(err)) err = conn.GetFsError(fs, nil) diff --git a/common/protocol_test.go b/common/protocol_test.go index ad2f3968..01bb4728 100644 --- a/common/protocol_test.go +++ b/common/protocol_test.go @@ -78,7 +78,7 @@ func TestMain(m *testing.M) { providerConf := config.GetProviderConf() logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver) - err = common.Initialize(config.GetCommonConfig()) + err = common.Initialize(config.GetCommonConfig(), 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) @@ -625,6 +625,8 @@ func TestFileNotAllowedErrors(t *testing.T) { func TestTruncateQuotaLimits(t *testing.T) { u := getTestUser() u.QuotaSize = 20 + u.UploadDataTransfer = 1000 + u.DownloadDataTransfer = 5000 mappedPath1 := filepath.Join(os.TempDir(), "mapped1") folderName1 := filepath.Base(mappedPath1) vdirPath1 := "/vmapped1" @@ -912,6 +914,13 @@ func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { defer client.Close() err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) assert.NoError(t, err) + f, err := client.Open(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + contents, err := io.ReadAll(f) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + assert.Len(t, contents, int(testFileSize)) err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) assert.NoError(t, err) err = writeSFTPFile(path.Join(vdirPath1, testFileName1), testFileSize1, client) @@ -1914,6 +1923,84 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) { assert.NoError(t, err) } +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.TotalDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFileSize := int64(524288) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(testFileName) + assert.NoError(t, err) + contents := make([]byte, testFileSize) + n, err := io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + assert.Len(t, contents, int(testFileSize)) + err = f.Close() + assert.NoError(t, err) + _, err = client.Open(testFileName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = writeSFTPFile(testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + // test the limit while uploading/downloading + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 1 + user.DownloadDataTransfer = 1 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFileSize := int64(450000) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(testFileName) + if assert.NoError(t, err) { + _, err = io.ReadAll(f) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + f, err = client.Open(testFileName) + if assert.NoError(t, err) { + _, err = io.ReadAll(f) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + } + + err = writeSFTPFile(testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestVirtualFoldersLink(t *testing.T) { u := getTestUser() mappedPath1 := filepath.Join(os.TempDir(), "vdir1") @@ -2284,7 +2371,7 @@ func TestDbDefenderErrors(t *testing.T) { configCopy := common.Config common.Config.DefenderConfig.Enabled = true common.Config.DefenderConfig.Driver = common.DefenderDriverProvider - err := common.Initialize(common.Config) + err := common.Initialize(common.Config, 0) assert.NoError(t, err) testIP := "127.1.1.1" @@ -2325,7 +2412,7 @@ func TestDbDefenderErrors(t *testing.T) { assert.NoError(t, err) common.Config = configCopy - err = common.Initialize(common.Config) + err = common.Initialize(common.Config, 0) assert.NoError(t, err) } @@ -2341,32 +2428,45 @@ func TestDelayedQuotaUpdater(t *testing.T) { u := getTestUser() u.QuotaFiles = 100 + u.TotalDataTransfer = 2000 user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) err = dataprovider.UpdateUserQuota(&user, 10, 6000, false) assert.NoError(t, err) - files, size, err := dataprovider.GetUsedQuota(user.Username) + err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, false) + assert.NoError(t, err) + files, size, ulSize, dlSize, err := dataprovider.GetUsedQuota(user.Username) assert.NoError(t, err) assert.Equal(t, 10, files) assert.Equal(t, int64(6000), size) + assert.Equal(t, int64(100), ulSize) + assert.Equal(t, int64(200), dlSize) userGet, err := dataprovider.UserExists(user.Username) assert.NoError(t, err) assert.Equal(t, 0, userGet.UsedQuotaFiles) assert.Equal(t, int64(0), userGet.UsedQuotaSize) + assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer) + assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer) err = dataprovider.UpdateUserQuota(&user, 10, 6000, true) assert.NoError(t, err) - files, size, err = dataprovider.GetUsedQuota(user.Username) + err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, true) + assert.NoError(t, err) + files, size, ulSize, dlSize, err = dataprovider.GetUsedQuota(user.Username) assert.NoError(t, err) assert.Equal(t, 10, files) assert.Equal(t, int64(6000), size) + assert.Equal(t, int64(100), ulSize) + assert.Equal(t, int64(200), dlSize) userGet, err = dataprovider.UserExists(user.Username) assert.NoError(t, err) assert.Equal(t, 10, userGet.UsedQuotaFiles) assert.Equal(t, int64(6000), userGet.UsedQuotaSize) + assert.Equal(t, int64(100), userGet.UsedUploadDataTransfer) + assert.Equal(t, int64(200), userGet.UsedDownloadDataTransfer) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) @@ -2559,6 +2659,7 @@ func TestGetQuotaError(t *testing.T) { t.Skip("this test is not available with the memory provider") } u := getTestUser() + u.TotalDataTransfer = 2000 mappedPath := filepath.Join(os.TempDir(), "vdir") folderName := filepath.Base(mappedPath) vdirPath := "/vpath" diff --git a/common/transfer.go b/common/transfer.go index 95c479cb..f33f5c68 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -41,6 +41,7 @@ type BaseTransfer struct { //nolint:maligned AbortTransfer int32 aTime time.Time mTime time.Time + transferQuota dataprovider.TransferQuota sync.Mutex errAbort error ErrTransfer error @@ -49,6 +50,7 @@ type BaseTransfer struct { //nolint:maligned // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string, transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs, + transferQuota dataprovider.TransferQuota, ) *BaseTransfer { t := &BaseTransfer{ ID: conn.GetTransferID(), @@ -68,6 +70,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat MaxWriteSize: maxWriteSize, AbortTransfer: 0, truncatedSize: truncatedSize, + transferQuota: transferQuota, Fs: fs, } @@ -75,6 +78,11 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat return t } +// GetTransferQuota returns data transfer quota limits +func (t *BaseTransfer) GetTransferQuota() dataprovider.TransferQuota { + return t.transferQuota +} + // SetFtpMode sets the FTP mode for the current transfer func (t *BaseTransfer) SetFtpMode(mode string) { t.ftpMode = mode @@ -140,9 +148,17 @@ func (t *BaseTransfer) GetTruncatedSize() int64 { return t.truncatedSize } -// GetMaxAllowedSize returns the max allowed size -func (t *BaseTransfer) GetMaxAllowedSize() int64 { - return t.MaxWriteSize +// HasSizeLimit returns true if there is an upload or download size limit +func (t *BaseTransfer) HasSizeLimit() bool { + if t.MaxWriteSize > 0 { + return true + } + if t.transferQuota.AllowedDLSize > 0 || t.transferQuota.AllowedULSize > 0 || + t.transferQuota.AllowedTotalSize > 0 { + return true + } + + return false } // GetVirtualPath returns the transfer virtual path @@ -182,6 +198,43 @@ func (t *BaseTransfer) SetCancelFn(cancelFn func()) { t.cancelFn = cancelFn } +// CheckRead returns an error if read if not allowed +func (t *BaseTransfer) CheckRead() error { + if t.transferQuota.AllowedDLSize == 0 && t.transferQuota.AllowedTotalSize == 0 { + return nil + } + if t.transferQuota.AllowedTotalSize > 0 { + if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize { + return t.Connection.GetReadQuotaExceededError() + } + } else if t.transferQuota.AllowedDLSize > 0 { + if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize { + return t.Connection.GetReadQuotaExceededError() + } + } + return nil +} + +// CheckWrite returns an error if write if not allowed +func (t *BaseTransfer) CheckWrite() error { + if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { + return t.Connection.GetQuotaExceededError() + } + if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 { + return nil + } + if t.transferQuota.AllowedTotalSize > 0 { + if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize { + return t.Connection.GetQuotaExceededError() + } + } else if t.transferQuota.AllowedULSize > 0 { + if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize { + return t.Connection.GetQuotaExceededError() + } + } + return nil +} + // Truncate changes the size of the opened file. // Supported for local fs only func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { @@ -196,6 +249,10 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { sizeDiff := initialSize - size t.MaxWriteSize += sizeDiff metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) + go func(ulSize, dlSize int64, user dataprovider.User) { + dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck + }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User) + atomic.StoreInt64(&t.BytesReceived, 0) } t.Unlock() @@ -262,7 +319,10 @@ func (t *BaseTransfer) Close() error { if t.isNewFile { numFiles = 1 } - metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) + metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), + t.transferType, t.ErrTransfer) + dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck + atomic.LoadInt64(&t.BytesSent), false) if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) { // if quota is exceeded we try to remove the partial file for uploads to local filesystem err = t.Fs.Remove(t.File.Name(), false) diff --git a/common/transfer_test.go b/common/transfer_test.go index f498cab1..8fcd1b9d 100644 --- a/common/transfer_test.go +++ b/common/transfer_test.go @@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) { wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10 conn := NewBaseConnection("id", ProtocolSCP, "", "", u) - transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs) + transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesReceived = testFileSize transfer.Connection.UpdateLastActivity() startTime := transfer.Connection.GetLastActivity() @@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) { err := transfer.Close() assert.NoError(t, err) - transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.BytesSent = testFileSize transfer.Connection.UpdateLastActivity() startTime = transfer.Connection.GetLastActivity() @@ -102,7 +102,7 @@ func TestRealPath(t *testing.T) { require.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", - TransferUpload, 0, 0, 0, 0, true, fs) + TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) rPath := transfer.GetRealFsPath(testFile) assert.Equal(t, testFile, rPath) rPath = conn.getRealFsPath(testFile) @@ -140,7 +140,7 @@ func TestTruncate(t *testing.T) { assert.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, - 100, 0, false, fs) + 100, 0, false, fs, dataprovider.TransferQuota{}) err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, @@ -158,7 +158,7 @@ func TestTruncate(t *testing.T) { } transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, - 100, 0, true, fs) + 100, 0, true, fs, dataprovider.TransferQuota{}) // file.Stat will fail on a closed file err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, @@ -168,7 +168,8 @@ func TestTruncate(t *testing.T) { err = transfer.Close() assert.NoError(t, err) - transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, + fs, dataprovider.TransferQuota{}) _, err = transfer.Truncate("mismatch", 0) assert.EqualError(t, err, errTransferMismatch.Error()) _, err = transfer.Truncate(testFile, 0) @@ -206,7 +207,7 @@ func TestTransferErrors(t *testing.T) { } conn := NewBaseConnection("id", ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, - 0, 0, 0, 0, true, fs) + 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) assert.Nil(t, transfer.cancelFn) assert.Equal(t, testFile, transfer.GetFsPath()) transfer.SetCancelFn(cancelFn) @@ -232,7 +233,8 @@ func TestTransferErrors(t *testing.T) { assert.FailNow(t, "unable to open test file") } fsPath := filepath.Join(os.TempDir(), "test_file") - transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, + fs, dataprovider.TransferQuota{}) transfer.BytesReceived = 9 transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, errFake.Error()) @@ -251,7 +253,8 @@ func TestTransferErrors(t *testing.T) { if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } - transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, + fs, dataprovider.TransferQuota{}) transfer.BytesReceived = 9 // the file is closed from the embedding struct before to call close err = file.Close() @@ -278,7 +281,7 @@ func TestRemovePartialCryptoFile(t *testing.T) { } conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, - 0, 0, 0, 0, true, fs) + 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer.ErrTransfer = errors.New("test error") _, err = transfer.getUploadFileSize() assert.Error(t, err) @@ -302,3 +305,120 @@ func TestFTPMode(t *testing.T) { transfer.SetFtpMode("active") assert.Equal(t, "active", transfer.ftpMode) } + +func TestTransferQuota(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + TotalDataTransfer: -1, + UploadDataTransfer: -1, + DownloadDataTransfer: -1, + }, + } + user.Filters.DataTransferLimits = []sdk.DataTransferLimit{ + { + Sources: []string{"127.0.0.1/32", "192.168.1.0/24"}, + TotalDataTransfer: 100, + UploadDataTransfer: 0, + DownloadDataTransfer: 0, + }, + { + Sources: []string{"172.16.0.0/24"}, + TotalDataTransfer: 0, + UploadDataTransfer: 120, + DownloadDataTransfer: 150, + }, + } + ul, dl, total := user.GetDataTransferLimits("127.0.1.1") + assert.Equal(t, int64(0), ul) + assert.Equal(t, int64(0), dl) + assert.Equal(t, int64(0), total) + ul, dl, total = user.GetDataTransferLimits("127.0.0.1") + assert.Equal(t, int64(0), ul) + assert.Equal(t, int64(0), dl) + assert.Equal(t, int64(100*1048576), total) + ul, dl, total = user.GetDataTransferLimits("192.168.1.4") + assert.Equal(t, int64(0), ul) + assert.Equal(t, int64(0), dl) + assert.Equal(t, int64(100*1048576), total) + ul, dl, total = user.GetDataTransferLimits("172.16.0.2") + assert.Equal(t, int64(120*1048576), ul) + assert.Equal(t, int64(150*1048576), dl) + assert.Equal(t, int64(0), total) + transferQuota := dataprovider.TransferQuota{} + assert.True(t, transferQuota.HasDownloadSpace()) + assert.True(t, transferQuota.HasUploadSpace()) + transferQuota.TotalSize = -1 + transferQuota.ULSize = -1 + transferQuota.DLSize = -1 + assert.True(t, transferQuota.HasDownloadSpace()) + assert.True(t, transferQuota.HasUploadSpace()) + transferQuota.TotalSize = 100 + transferQuota.AllowedTotalSize = 10 + assert.True(t, transferQuota.HasDownloadSpace()) + assert.True(t, transferQuota.HasUploadSpace()) + transferQuota.AllowedTotalSize = 0 + assert.False(t, transferQuota.HasDownloadSpace()) + assert.False(t, transferQuota.HasUploadSpace()) + transferQuota.TotalSize = 0 + transferQuota.DLSize = 100 + transferQuota.ULSize = 50 + transferQuota.AllowedTotalSize = 0 + assert.False(t, transferQuota.HasDownloadSpace()) + assert.False(t, transferQuota.HasUploadSpace()) + transferQuota.AllowedDLSize = 1 + transferQuota.AllowedULSize = 1 + assert.True(t, transferQuota.HasDownloadSpace()) + assert.True(t, transferQuota.HasUploadSpace()) + transferQuota.AllowedDLSize = -10 + transferQuota.AllowedULSize = -1 + assert.False(t, transferQuota.HasDownloadSpace()) + assert.False(t, transferQuota.HasUploadSpace()) + + conn := NewBaseConnection("", ProtocolSFTP, "", "", user) + transfer := NewBaseTransfer(nil, conn, nil, "file.txt", "file.txt", "/transfer_test_file", TransferUpload, + 0, 0, 0, 0, true, vfs.NewOsFs("", os.TempDir(), ""), dataprovider.TransferQuota{}) + err := transfer.CheckRead() + assert.NoError(t, err) + err = transfer.CheckWrite() + assert.NoError(t, err) + + transfer.transferQuota = dataprovider.TransferQuota{ + AllowedTotalSize: 10, + } + transfer.BytesReceived = 5 + transfer.BytesSent = 4 + err = transfer.CheckRead() + assert.NoError(t, err) + err = transfer.CheckWrite() + assert.NoError(t, err) + + transfer.BytesSent = 6 + err = transfer.CheckRead() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) + } + err = transfer.CheckWrite() + assert.True(t, conn.IsQuotaExceededError(err)) + + transferQuota = dataprovider.TransferQuota{ + AllowedTotalSize: 0, + AllowedULSize: 10, + AllowedDLSize: 5, + } + transfer.transferQuota = transferQuota + assert.Equal(t, transferQuota, transfer.GetTransferQuota()) + err = transfer.CheckRead() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) + } + err = transfer.CheckWrite() + assert.NoError(t, err) + + transfer.BytesReceived = 11 + err = transfer.CheckRead() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) + } + err = transfer.CheckWrite() + assert.True(t, conn.IsQuotaExceededError(err)) +} diff --git a/common/transferschecker.go b/common/transferschecker.go index 35ba128b..2e47eb61 100644 --- a/common/transferschecker.go +++ b/common/transferschecker.go @@ -11,8 +11,14 @@ import ( ) type overquotaTransfer struct { - ConnID string - TransferID int64 + ConnID string + TransferID int64 + TransferType int +} + +type uploadAggregationKey struct { + Username string + FolderName string } // TransfersChecker defines the interface that transfer checkers must implement. @@ -21,17 +27,205 @@ type overquotaTransfer struct { type TransfersChecker interface { AddTransfer(transfer dataprovider.ActiveTransfer) RemoveTransfer(ID int64, connectionID string) - UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) + UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) GetOverquotaTransfers() []overquotaTransfer } -func getTransfersChecker() TransfersChecker { +func getTransfersChecker(isShared int) TransfersChecker { + if isShared == 1 { + logger.Info(logSender, "", "using provider transfer checker") + return &transfersCheckerDB{} + } + logger.Info(logSender, "", "using memory transfer checker") return &transfersCheckerMem{} } +type baseTransferChecker struct { + transfers []dataprovider.ActiveTransfer +} + +func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize, + dlSize int64, +) bool { + ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits(transfer.IP) + if totalQuota > 0 { + allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer) + if ulSize+dlSize > allowedSize { + return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0 + } + } + if dlQuota > 0 { + allowedSize := dlQuota - user.UsedDownloadDataTransfer + if dlSize > allowedSize { + return transfer.CurrentDLSize > 0 + } + } + if ulQuota > 0 { + allowedSize := ulQuota - user.UsedUploadDataTransfer + if ulSize > allowedSize { + return transfer.CurrentULSize > 0 + } + } + return false +} + +func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) { + var result int64 + + if folderName != "" { + for _, folder := range user.VirtualFolders { + if folder.Name == folderName { + if folder.QuotaSize > 0 { + return folder.QuotaSize - folder.UsedQuotaSize, nil + } + } + } + } else { + if user.QuotaSize > 0 { + return user.QuotaSize - user.UsedQuotaSize, nil + } + } + + return result, errors.New("no quota limit defined") +} + +func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool, +) (map[string]bool, map[string][]dataprovider.ActiveTransfer) { + aggregations := make(map[string][]dataprovider.ActiveTransfer) + for _, transfer := range t.transfers { + aggregations[transfer.Username] = append(aggregations[transfer.Username], transfer) + if len(aggregations[transfer.Username]) > 1 { + if _, ok := usersToFetch[transfer.Username]; !ok { + usersToFetch[transfer.Username] = false + } + } + } + + return usersToFetch, aggregations +} + +func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) { + usersToFetch := make(map[string]bool) + aggregations := make(map[int][]dataprovider.ActiveTransfer) + var keys []uploadAggregationKey + + for _, transfer := range t.transfers { + if transfer.Type != TransferUpload { + continue + } + key := -1 + for idx, k := range keys { + if k.Username == transfer.Username && k.FolderName == transfer.FolderName { + key = idx + break + } + } + if key == -1 { + key = len(keys) + } + keys = append(keys, uploadAggregationKey{ + Username: transfer.Username, + FolderName: transfer.FolderName, + }) + + aggregations[key] = append(aggregations[key], transfer) + if len(aggregations[key]) > 1 { + if transfer.FolderName != "" { + usersToFetch[transfer.Username] = true + } else { + if _, ok := usersToFetch[transfer.Username]; !ok { + usersToFetch[transfer.Username] = false + } + } + } + } + + return usersToFetch, aggregations +} + +func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) { + users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) + if err != nil { + return nil, err + } + + usersMap := make(map[string]dataprovider.User) + + for _, user := range users { + usersMap[user.Username] = user + } + + return usersMap, nil +} + +func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool, + uploadAggregations map[int][]dataprovider.ActiveTransfer, + userAggregations map[string][]dataprovider.ActiveTransfer, +) []overquotaTransfer { + if len(usersToFetch) == 0 { + return nil + } + usersMap, err := t.getUsersToCheck(usersToFetch) + if err != nil { + logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err) + return nil + } + + var overquotaTransfers []overquotaTransfer + + for _, transfers := range uploadAggregations { + username := transfers[0].Username + folderName := transfers[0].FolderName + remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName) + if err != nil { + continue + } + var usedDiskQuota int64 + for _, tr := range transfers { + // We optimistically assume that a cloud transfer that replaces an existing + // file will be successful + usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize + } + logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v", + username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota) + if usedDiskQuota > remaningDiskQuota { + for _, tr := range transfers { + if tr.CurrentULSize > tr.TruncatedSize { + overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ + ConnID: tr.ConnID, + TransferID: tr.ID, + TransferType: tr.Type, + }) + } + } + } + } + + for username, transfers := range userAggregations { + var ulSize, dlSize int64 + for _, tr := range transfers { + ulSize += tr.CurrentULSize + dlSize += tr.CurrentDLSize + } + logger.Debug(logSender, "", "username %#v, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v", + username, len(transfers), ulSize, dlSize) + for _, tr := range transfers { + if t.isDataTransferExceeded(usersMap[username], tr, ulSize, dlSize) { + overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ + ConnID: tr.ConnID, + TransferID: tr.ID, + TransferType: tr.Type, + }) + } + } + } + + return overquotaTransfers +} + type transfersCheckerMem struct { sync.RWMutex - transfers []dataprovider.ActiveTransfer + baseTransferChecker } func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) { @@ -55,7 +249,7 @@ func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) { } } -func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) { +func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { t.Lock() defer t.Unlock() @@ -69,99 +263,53 @@ func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int } } -func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) { - var result int64 - - if folderName != "" { - for _, folder := range user.VirtualFolders { - if folder.Name == folderName { - if folder.QuotaSize > 0 { - return folder.QuotaSize - folder.UsedQuotaSize, nil - } - } - } - } else { - if user.QuotaSize > 0 { - return user.QuotaSize - user.UsedQuotaSize, nil - } - } - - return result, errors.New("no quota limit defined") -} - -func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) { - t.RLock() - defer t.RUnlock() - - usersToFetch := make(map[string]bool) - aggregations := make(map[string][]dataprovider.ActiveTransfer) - for _, transfer := range t.transfers { - key := transfer.GetKey() - aggregations[key] = append(aggregations[key], transfer) - if len(aggregations[key]) > 1 { - if transfer.FolderName != "" { - usersToFetch[transfer.Username] = true - } else { - if _, ok := usersToFetch[transfer.Username]; !ok { - usersToFetch[transfer.Username] = false - } - } - } - } - - return usersToFetch, aggregations -} - func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer { - usersToFetch, aggregations := t.aggregateTransfers() + t.RLock() - if len(usersToFetch) == 0 { - return nil - } + usersToFetch, uploadAggregations := t.aggregateUploadTransfers() + usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) - users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) - if err != nil { - logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err) - return nil - } + t.RUnlock() - usersMap := make(map[string]dataprovider.User) + return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) +} - for _, user := range users { - usersMap[user.Username] = user - } +type transfersCheckerDB struct { + baseTransferChecker + lastCleanup time.Time +} - var overquotaTransfers []overquotaTransfer +func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) { + dataprovider.AddActiveTransfer(transfer) +} - for _, transfers := range aggregations { - if len(transfers) > 1 { - username := transfers[0].Username - folderName := transfers[0].FolderName - // transfer type is always upload for now - remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName) - if err != nil { - continue - } - var usedDiskQuota int64 - for _, tr := range transfers { - // We optimistically assume that a cloud transfer that replaces an existing - // file will be successful - usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize - } - logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v", - username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota) - if usedDiskQuota > remaningDiskQuota { - for _, tr := range transfers { - if tr.CurrentULSize > tr.TruncatedSize { - overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ - ConnID: tr.ConnID, - TransferID: tr.ID, - }) - } - } - } +func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) { + dataprovider.RemoveActiveTransfer(ID, connectionID) +} + +func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { + dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID) +} + +func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer { + if t.lastCleanup.IsZero() || t.lastCleanup.Add(periodicTimeoutCheckInterval*15).Before(time.Now()) { + before := time.Now().Add(-periodicTimeoutCheckInterval * 5) + err := dataprovider.CleanupActiveTransfers(before) + logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", err) + if err == nil { + t.lastCleanup = time.Now() } } + var err error + from := time.Now().Add(-periodicTimeoutCheckInterval * 2) + t.transfers, err = dataprovider.GetActiveTransfers(from) + if err != nil { + logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err) + return nil + } - return overquotaTransfers + usersToFetch, uploadAggregations := t.aggregateUploadTransfers() + usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) + + return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) } diff --git a/common/transferschecker_test.go b/common/transferschecker_test.go index 9345b1d0..50846ea4 100644 --- a/common/transferschecker_test.go +++ b/common/transferschecker_test.go @@ -60,7 +60,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { BaseConnection: conn1, } transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), - "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser) + "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived = 150 Connections.Add(fakeConn1) // the transferschecker will do nothing if there is only one ongoing transfer @@ -73,7 +73,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { BaseConnection: conn2, } transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), - "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser) + "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived = 50 transfer2.BytesReceived = 60 Connections.Add(fakeConn2) @@ -84,7 +84,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { BaseConnection: conn3, } transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), - "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser) + "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer3.BytesReceived = 60 // this value will be ignored, this is a download Connections.Add(fakeConn3) @@ -145,7 +145,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { } transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"), filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0, - 100, 0, true, fsFolder) + 100, 0, true, fsFolder, dataprovider.TransferQuota{}) Connections.Add(fakeConn4) connID5 := xid.New().String() conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user) @@ -154,7 +154,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { } transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"), filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0, - 100, 0, true, fsFolder) + 100, 0, true, fsFolder, dataprovider.TransferQuota{}) Connections.Add(fakeConn5) transfer4.BytesReceived = 50 @@ -188,6 +188,17 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { assert.NoError(t, err) } + err = transfer1.Close() + assert.NoError(t, err) + err = transfer2.Close() + assert.NoError(t, err) + err = transfer3.Close() + assert.NoError(t, err) + err = transfer4.Close() + assert.NoError(t, err) + err = transfer5.Close() + assert.NoError(t, err) + Connections.Remove(fakeConn1.GetID()) Connections.Remove(fakeConn2.GetID()) Connections.Remove(fakeConn3.GetID()) @@ -207,6 +218,118 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { assert.NoError(t, err) } +func TestTransferCheckerTransferQuota(t *testing.T) { + username := "transfers_check_username" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "test_pwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + TotalDataTransfer: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + err := dataprovider.AddUser(&user, "", "") + assert.NoError(t, err) + + connID1 := xid.New().String() + fsUser, err := user.GetFilesystemForPath("/file1", connID1) + assert.NoError(t, err) + conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "192.168.1.1", user) + fakeConn1 := &fakeConnection{ + BaseConnection: conn1, + } + transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) + transfer1.BytesReceived = 150 + Connections.Add(fakeConn1) + // the transferschecker will do nothing if there is only one ongoing transfer + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + + connID2 := xid.New().String() + conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "127.0.0.1", user) + fakeConn2 := &fakeConnection{ + BaseConnection: conn2, + } + transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), + "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) + transfer2.BytesReceived = 150 + Connections.Add(fakeConn2) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + // now test overquota + transfer1.BytesReceived = 1024*1024 + 1 + transfer2.BytesReceived = 0 + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) + assert.Nil(t, transfer2.errAbort) + transfer1.errAbort = nil + transfer1.BytesReceived = 1024*1024 + 1 + transfer2.BytesReceived = 1024 + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) + assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort)) + transfer1.BytesReceived = 0 + transfer2.BytesReceived = 0 + transfer1.errAbort = nil + transfer2.errAbort = nil + + err = transfer1.Close() + assert.NoError(t, err) + err = transfer2.Close() + assert.NoError(t, err) + Connections.Remove(fakeConn1.GetID()) + Connections.Remove(fakeConn2.GetID()) + + connID3 := xid.New().String() + conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) + fakeConn3 := &fakeConnection{ + BaseConnection: conn3, + } + transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) + transfer3.BytesSent = 150 + Connections.Add(fakeConn3) + + connID4 := xid.New().String() + conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) + fakeConn4 := &fakeConnection{ + BaseConnection: conn4, + } + transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), + "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) + transfer4.BytesSent = 150 + Connections.Add(fakeConn4) + Connections.checkTransfers() + assert.Nil(t, transfer3.errAbort) + assert.Nil(t, transfer4.errAbort) + + transfer3.BytesSent = 512 * 1024 + transfer4.BytesSent = 512*1024 + 1 + Connections.checkTransfers() + if assert.Error(t, transfer3.errAbort) { + assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error()) + } + if assert.Error(t, transfer4.errAbort) { + assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error()) + } + + Connections.Remove(fakeConn3.GetID()) + Connections.Remove(fakeConn4.GetID()) + stats := Connections.GetStats() + assert.Len(t, stats, 0) + + err = dataprovider.DeleteUser(user.Username, "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestAggregateTransfers(t *testing.T) { checker := transfersCheckerMem{} checker.AddTransfer(dataprovider.ActiveTransfer{ @@ -221,7 +344,7 @@ func TestAggregateTransfers(t *testing.T) { CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations := checker.aggregateTransfers() + usersToFetch, aggregations := checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 0) assert.Len(t, aggregations, 1) @@ -238,9 +361,9 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 0) - assert.Len(t, aggregations, 2) + assert.Len(t, aggregations, 1) checker.AddTransfer(dataprovider.ActiveTransfer{ ID: 1, @@ -255,9 +378,9 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 0) - assert.Len(t, aggregations, 3) + assert.Len(t, aggregations, 2) checker.AddTransfer(dataprovider.ActiveTransfer{ ID: 1, @@ -272,9 +395,9 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 0) - assert.Len(t, aggregations, 4) + assert.Len(t, aggregations, 3) checker.AddTransfer(dataprovider.ActiveTransfer{ ID: 1, @@ -289,13 +412,13 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 1) val, ok := usersToFetch["user"] assert.True(t, ok) assert.False(t, val) - assert.Len(t, aggregations, 4) - aggregate, ok := aggregations["user0"] + assert.Len(t, aggregations, 3) + aggregate, ok := aggregations[0] assert.True(t, ok) assert.Len(t, aggregate, 2) @@ -312,13 +435,13 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 1) val, ok = usersToFetch["user"] assert.True(t, ok) assert.False(t, val) - assert.Len(t, aggregations, 4) - aggregate, ok = aggregations["user0"] + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] assert.True(t, ok) assert.Len(t, aggregate, 3) @@ -335,16 +458,16 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 1) val, ok = usersToFetch["user"] assert.True(t, ok) assert.True(t, val) - assert.Len(t, aggregations, 4) - aggregate, ok = aggregations["user0"] + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] assert.True(t, ok) assert.Len(t, aggregate, 3) - aggregate, ok = aggregations["userfolder0"] + aggregate, ok = aggregations[1] assert.True(t, ok) assert.Len(t, aggregate, 2) @@ -361,20 +484,67 @@ func TestAggregateTransfers(t *testing.T) { UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), }) - usersToFetch, aggregations = checker.aggregateTransfers() + usersToFetch, aggregations = checker.aggregateUploadTransfers() assert.Len(t, usersToFetch, 1) val, ok = usersToFetch["user"] assert.True(t, ok) assert.True(t, val) - assert.Len(t, aggregations, 4) - aggregate, ok = aggregations["user0"] + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] assert.True(t, ok) assert.Len(t, aggregate, 4) - aggregate, ok = aggregations["userfolder0"] + aggregate, ok = aggregations[1] assert.True(t, ok) assert.Len(t, aggregate, 2) } +func TestDataTransferExceeded(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + TotalDataTransfer: 1, + }, + } + transfer := dataprovider.ActiveTransfer{ + CurrentULSize: 0, + CurrentDLSize: 0, + } + user.UsedDownloadDataTransfer = 1024 * 1024 + user.UsedUploadDataTransfer = 512 * 1024 + checker := transfersCheckerMem{} + res := checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.False(t, res) + transfer.CurrentULSize = 1 + res = checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.True(t, res) + user.UsedDownloadDataTransfer = 512*1024 - 100 + user.UsedUploadDataTransfer = 512*1024 - 100 + res = checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.False(t, res) + res = checker.isDataTransferExceeded(user, transfer, 101, 100) + assert.True(t, res) + + user.TotalDataTransfer = 0 + user.DownloadDataTransfer = 1 + user.UsedDownloadDataTransfer = 512 * 1024 + transfer.CurrentULSize = 0 + transfer.CurrentDLSize = 100 + res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024) + assert.False(t, res) + res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024+1) + assert.True(t, res) + + user.DownloadDataTransfer = 0 + user.UploadDataTransfer = 1 + user.UsedUploadDataTransfer = 512 * 1024 + transfer.CurrentULSize = 0 + transfer.CurrentDLSize = 0 + res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0) + assert.False(t, res) + transfer.CurrentULSize = 1 + res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0) + assert.True(t, res) +} + func TestGetUsersForQuotaCheck(t *testing.T) { usersToFetch := make(map[string]bool) for i := 0; i < 50; i++ { @@ -407,6 +577,17 @@ func TestGetUsersForQuotaCheck(t *testing.T) { QuotaSize: 100, }, }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + DataTransferLimits: []sdk.DataTransferLimit{ + { + Sources: []string{"172.16.0.0/16"}, + UploadDataTransfer: 50, + DownloadDataTransfer: 80, + }, + }, + }, + }, } err = dataprovider.AddUser(&user, "", "") assert.NoError(t, err) @@ -434,6 +615,14 @@ func TestGetUsersForQuotaCheck(t *testing.T) { assert.Len(t, user.VirtualFolders, 0, user.Username) } } + ul, dl, total := user.GetDataTransferLimits("127.1.1.1") + assert.Equal(t, int64(0), ul) + assert.Equal(t, int64(0), dl) + assert.Equal(t, int64(0), total) + ul, dl, total = user.GetDataTransferLimits("172.16.2.3") + assert.Equal(t, int64(50*1024*1024), ul) + assert.Equal(t, int64(80*1024*1024), dl) + assert.Equal(t, int64(0), total) } for i := 0; i < 40; i++ { @@ -447,3 +636,87 @@ func TestGetUsersForQuotaCheck(t *testing.T) { assert.NoError(t, err) assert.Len(t, users, 0) } + +func TestDBTransferChecker(t *testing.T) { + if !isDbTransferCheckerSupported() { + t.Skip("this test is not supported with the current database provider") + } + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + providerConf.IsShared = 1 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + c := getTransfersChecker(1) + checker, ok := c.(*transfersCheckerDB) + assert.True(t, ok) + assert.True(t, checker.lastCleanup.IsZero()) + transfer1 := dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferDownload, + ConnID: xid.New().String(), + Username: "user1", + FolderName: "folder1", + IP: "127.0.0.1", + } + checker.AddTransfer(transfer1) + transfers, err := dataprovider.GetActiveTransfers(time.Now().Add(24 * time.Hour)) + assert.NoError(t, err) + assert.Len(t, transfers, 0) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + var createdAt, updatedAt int64 + if assert.Len(t, transfers, 1) { + transfer := transfers[0] + assert.Equal(t, transfer1.ID, transfer.ID) + assert.Equal(t, transfer1.Type, transfer.Type) + assert.Equal(t, transfer1.ConnID, transfer.ConnID) + assert.Equal(t, transfer1.Username, transfer.Username) + assert.Equal(t, transfer1.IP, transfer.IP) + assert.Equal(t, transfer1.FolderName, transfer.FolderName) + assert.Greater(t, transfer.CreatedAt, int64(0)) + assert.Greater(t, transfer.UpdatedAt, int64(0)) + assert.Equal(t, int64(0), transfer.CurrentDLSize) + assert.Equal(t, int64(0), transfer.CurrentULSize) + createdAt = transfer.CreatedAt + updatedAt = transfer.UpdatedAt + } + time.Sleep(100 * time.Millisecond) + checker.UpdateTransferCurrentSizes(100, 150, transfer1.ID, transfer1.ConnID) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + if assert.Len(t, transfers, 1) { + transfer := transfers[0] + assert.Equal(t, int64(150), transfer.CurrentDLSize) + assert.Equal(t, int64(100), transfer.CurrentULSize) + assert.Equal(t, createdAt, transfer.CreatedAt) + assert.Greater(t, transfer.UpdatedAt, updatedAt) + } + res := checker.GetOverquotaTransfers() + assert.Len(t, res, 0) + + checker.RemoveTransfer(transfer1.ID, transfer1.ConnID) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + assert.Len(t, transfers, 0) + + err = dataprovider.Close() + assert.NoError(t, err) + res = checker.GetOverquotaTransfers() + assert.Len(t, res, 0) + providerConf.IsShared = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func isDbTransferCheckerSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index 9d944855..a9b84036 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -251,6 +251,41 @@ func (p *BoltProvider) updateAdminLastLogin(username string) error { }) } +func (p *BoltProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %#v does not exist, unable to update transfer quota", + username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + if !reset { + user.UsedUploadDataTransfer += uploadSize + user.UsedDownloadDataTransfer += downloadSize + } else { + user.UsedUploadDataTransfer = uploadSize + user.UsedDownloadDataTransfer = downloadSize + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v", + username, uploadSize, downloadSize, reset) + return err + }) +} + func (p *BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return p.dbHandle.Update(func(tx *bolt.Tx) error { bucket, err := getUsersBucket(tx) @@ -285,13 +320,13 @@ func (p *BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, }) } -func (p *BoltProvider) getUsedQuota(username string) (int, int64, error) { +func (p *BoltProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { user, err := p.userExists(username) if err != nil { providerLog(logger.LevelError, "unable to get quota for user %v error: %v", username, err) - return 0, 0, err + return 0, 0, 0, 0, err } - return user.UsedQuotaFiles, user.UsedQuotaSize, err + return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err } func (p *BoltProvider) adminExists(username string) (Admin, error) { @@ -513,6 +548,8 @@ func (p *BoltProvider) addUser(user *User) error { user.LastQuotaUpdate = 0 user.UsedQuotaSize = 0 user.UsedQuotaFiles = 0 + user.UsedUploadDataTransfer = 0 + user.UsedDownloadDataTransfer = 0 user.LastLogin = 0 user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) @@ -569,6 +606,8 @@ func (p *BoltProvider) updateUser(user *User) error { user.LastQuotaUpdate = oldUser.LastQuotaUpdate user.UsedQuotaSize = oldUser.UsedQuotaSize user.UsedQuotaFiles = oldUser.UsedQuotaFiles + user.UsedUploadDataTransfer = oldUser.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = oldUser.UsedDownloadDataTransfer user.LastLogin = oldUser.LastLogin user.CreatedAt = oldUser.CreatedAt user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) @@ -1444,6 +1483,26 @@ func (p *BoltProvider) cleanupDefender(from int64) error { return ErrNotImplemented } +func (p *BoltProvider) addActiveTransfer(transfer ActiveTransfer) error { + return ErrNotImplemented +} + +func (p *BoltProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) cleanupActiveTransfers(before time.Time) error { + return ErrNotImplemented +} + +func (p *BoltProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return nil, ErrNotImplemented +} + func (p *BoltProvider) close() error { return p.dbHandle.Close() } @@ -1471,6 +1530,8 @@ func (p *BoltProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err + case version == 15: + return updateBoltDatabaseVersion(p.dbHandle, 16) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -1492,6 +1553,8 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { + case 16: + return updateBoltDatabaseVersion(p.dbHandle, 15) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -1765,7 +1828,7 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) { return dbVersion, err } -/*func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error { +func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error { err := dbHandle.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket(dbVersionBucket) if bucket == nil { @@ -1781,4 +1844,4 @@ func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) { return bucket.Put(dbVersionKey, buf) }) return err -}*/ +} diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 4e2ccb61..68a7170a 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -71,7 +71,7 @@ const ( CockroachDataProviderName = "cockroachdb" // DumpVersion defines the version for the dump. // For restore/load we support the current version and the previous one - DumpVersion = 10 + DumpVersion = 11 argonPwdPrefix = "$argon2id$" bcryptPwdPrefix = "$2a$" @@ -165,6 +165,7 @@ var ( sqlTableShares = "shares" sqlTableDefenderHosts = "defender_hosts" sqlTableDefenderEvents = "defender_events" + sqlTableActiveTransfers = "active_transfers" sqlTableSchemaVersion = "schema_version" argon2Params *argon2id.Params lastLoginMinDelay = 10 * time.Minute @@ -367,10 +368,20 @@ type Config struct { // MySQL, PostgreSQL and CockroachDB can be shared, this setting is ignored for other data // providers. For shared data providers, SFTPGo periodically reloads the latest updated users, // based on the "updated_at" field, and updates its internal caches if users are updated from - // a different instance. This check, if enabled, is executed every 10 minutes + // a different instance. This check, if enabled, is executed every 10 minutes. + // For shared data providers, active transfers are persisted in the database and thus + // quota checks between ongoing transfers will work cross multiple instances IsShared int `json:"is_shared" mapstructure:"is_shared"` } +// GetShared returns the provider share mode +func (c *Config) GetShared() int { + if !util.IsStringInSlice(c.Driver, sharedProviders) { + return 0 + } + return c.IsShared +} + // IsDefenderSupported returns true if the configured provider supports the defender func (c *Config) IsDefenderSupported() bool { switch c.Driver { @@ -388,6 +399,7 @@ type ActiveTransfer struct { ConnID string Username string FolderName string + IP string TruncatedSize int64 CurrentULSize int64 CurrentDLSize int64 @@ -395,10 +407,36 @@ type ActiveTransfer struct { UpdatedAt int64 } -// GetKey returns an aggregation key. -// The same key will be returned for similar transfers -func (t *ActiveTransfer) GetKey() string { - return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type) +// TransferQuota stores the allowed transfer quota fields +type TransferQuota struct { + ULSize int64 + DLSize int64 + TotalSize int64 + AllowedULSize int64 + AllowedDLSize int64 + AllowedTotalSize int64 +} + +// HasUploadSpace returns true if there is transfer upload space available +func (q *TransferQuota) HasUploadSpace() bool { + if q.TotalSize <= 0 && q.ULSize <= 0 { + return true + } + if q.TotalSize > 0 { + return q.AllowedTotalSize > 0 + } + return q.AllowedULSize > 0 +} + +// HasDownloadSpace returns true if there is transfer download space available +func (q *TransferQuota) HasDownloadSpace() bool { + if q.TotalSize <= 0 && q.DLSize <= 0 { + return true + } + if q.TotalSize > 0 { + return q.AllowedTotalSize > 0 + } + return q.AllowedDLSize > 0 } // DefenderEntry defines a defender entry @@ -488,7 +526,8 @@ type Provider interface { validateUserAndPubKey(username string, pubKey []byte) (User, string, error) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error - getUsedQuota(username string) (int, int64, error) + updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error + getUsedQuota(username string) (int, int64, int64, int64, error) userExists(username string) (User, error) addUser(user *User) error updateUser(user *User) error @@ -537,6 +576,11 @@ type Provider interface { addDefenderEvent(ip string, score int) error setDefenderBanTime(ip string, banTime int64) error cleanupDefender(from int64) error + addActiveTransfer(transfer ActiveTransfer) error + updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error + removeActiveTransfer(transferID int64, connectionID string) error + cleanupActiveTransfers(before time.Time) error + getActiveTransfers(from time.Time) ([]ActiveTransfer, error) checkAvailability() error close() error reloadConfig() error @@ -673,10 +717,14 @@ func validateSQLTablesPrefix() error { sqlTableAdmins = config.SQLTablesPrefix + sqlTableAdmins sqlTableAPIKeys = config.SQLTablesPrefix + sqlTableAPIKeys sqlTableShares = config.SQLTablesPrefix + sqlTableShares + sqlTableDefenderEvents = config.SQLTablesPrefix + sqlTableDefenderEvents + sqlTableDefenderHosts = config.SQLTablesPrefix + sqlTableDefenderHosts + sqlTableActiveTransfers = config.SQLTablesPrefix + sqlTableActiveTransfers sqlTableSchemaVersion = config.SQLTablesPrefix + sqlTableSchemaVersion providerLog(logger.LevelDebug, "sql table for users %#v, folders %#v folders mapping %#v admins %#v "+ - "api keys %#v shares %#v schema version %#v", sqlTableUsers, sqlTableFolders, sqlTableFoldersMapping, - sqlTableAdmins, sqlTableAPIKeys, sqlTableShares, sqlTableSchemaVersion) + "api keys %#v shares %#v defender hosts %#v defender events %#v transfers %#v schema version %#v", + sqlTableUsers, sqlTableFolders, sqlTableFoldersMapping, sqlTableAdmins, sqlTableAPIKeys, + sqlTableShares, sqlTableDefenderHosts, sqlTableDefenderEvents, sqlTableActiveTransfers, sqlTableSchemaVersion) } return nil } @@ -1026,7 +1074,7 @@ func UpdateAdminLastLogin(admin *Admin) { } } -// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd. +// UpdateUserQuota updates the quota for the given SFTPGo user adding filesAdd and sizeAdd. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error { if config.TrackQuota == 0 { @@ -1066,17 +1114,41 @@ func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, size return nil } -// GetUsedQuota returns the used quota for the given SFTP user. -func GetUsedQuota(username string) (int, int64, error) { +// UpdateUserTransferQuota updates the transfer quota for the given SFTPGo user. +// If reset is true uploadSize and downloadSize indicates the actual sizes instead of the difference. +func UpdateUserTransferQuota(user *User, uploadSize, downloadSize int64, reset bool) error { if config.TrackQuota == 0 { - return 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) + return util.NewMethodDisabledError(trackQuotaDisabledError) + } else if config.TrackQuota == 2 && !reset && !user.HasTransferQuotaRestrictions() { + return nil } - files, size, err := provider.getUsedQuota(username) + if downloadSize == 0 && uploadSize == 0 && !reset { + return nil + } + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetUserTransferQuota(user.Username) + } + return provider.updateTransferQuota(user.Username, uploadSize, downloadSize, reset) + } + delayedQuotaUpdater.updateUserTransferQuota(user.Username, uploadSize, downloadSize) + return nil +} + +// GetUsedQuota returns the used quota for the given SFTPGo user. +func GetUsedQuota(username string) (int, int64, int64, int64, error) { + if config.TrackQuota == 0 { + return 0, 0, 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) + } + files, size, ulTransferSize, dlTransferSize, err := provider.getUsedQuota(username) if err != nil { - return files, size, err + return files, size, ulTransferSize, dlTransferSize, err } delayedFiles, delayedSize := delayedQuotaUpdater.getUserPendingQuota(username) - return files + delayedFiles, size + delayedSize, err + delayedUlTransferSize, delayedDLTransferSize := delayedQuotaUpdater.getUserPendingTransferQuota(username) + + return files + delayedFiles, size + delayedSize, ulTransferSize + delayedUlTransferSize, + dlTransferSize + delayedDLTransferSize, err } // GetUsedVirtualFolderQuota returns the used quota for the given virtual folder. @@ -1262,6 +1334,46 @@ func DeleteUser(username, executor, ipAddress string) error { return err } +// AddActiveTransfer stores the specified transfer +func AddActiveTransfer(transfer ActiveTransfer) { + if err := provider.addActiveTransfer(transfer); err != nil { + providerLog(logger.LevelError, "unable to add transfer id %v, connection id %v: %v", + transfer.ID, transfer.ConnID, err) + } +} + +// UpdateActiveTransferSizes updates the current upload and download sizes for the specified transfer +func UpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) { + if err := provider.updateActiveTransferSizes(ulSize, dlSize, transferID, connectionID); err != nil { + providerLog(logger.LevelError, "unable to update sizes for transfer id %v, connection id %v: %v", + transferID, connectionID, err) + } +} + +// RemoveActiveTransfer removes the specified transfer +func RemoveActiveTransfer(transferID int64, connectionID string) { + if err := provider.removeActiveTransfer(transferID, connectionID); err != nil { + providerLog(logger.LevelError, "unable to delete transfer id %v, connection id %v: %v", + transferID, connectionID, err) + } +} + +// CleanupActiveTransfers removes the transfer before the specified time +func CleanupActiveTransfers(before time.Time) error { + err := provider.cleanupActiveTransfers(before) + if err == nil { + providerLog(logger.LevelDebug, "deleted active transfers updated before: %v", before) + } else { + providerLog(logger.LevelError, "error deleting active transfers updated before %v: %v", before, err) + } + return err +} + +// GetActiveTransfers retrieves the active transfers with an update time after the specified value +func GetActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return provider.getActiveTransfers(from) +} + // ReloadConfig reloads provider configuration. // Currently only implemented for memory provider, allows to reload the users // from the configured file, if defined @@ -1780,6 +1892,9 @@ func validateIPFilters(user *User) error { } func validateBandwidthLimit(bl sdk.BandwidthLimit) error { + if len(bl.Sources) == 0 { + return util.NewValidationError("no bandwidth limit source specified") + } for _, source := range bl.Sources { _, _, err := net.ParseCIDR(source) if err != nil { @@ -1789,7 +1904,7 @@ func validateBandwidthLimit(bl sdk.BandwidthLimit) error { return nil } -func validateBandwidthLimitFilters(user *User) error { +func validateBandwidthLimitsFilter(user *User) error { for idx, bandwidthLimit := range user.Filters.BandwidthLimits { user.Filters.BandwidthLimits[idx].Sources = util.RemoveDuplicates(bandwidthLimit.Sources) if err := validateBandwidthLimit(bandwidthLimit); err != nil { @@ -1805,12 +1920,35 @@ func validateBandwidthLimitFilters(user *User) error { return nil } +func validateTransferLimitsFilter(user *User) error { + for idx, limit := range user.Filters.DataTransferLimits { + user.Filters.DataTransferLimits[idx].Sources = util.RemoveDuplicates(limit.Sources) + if len(limit.Sources) == 0 { + return util.NewValidationError("no data transfer limit source specified") + } + for _, source := range limit.Sources { + _, _, err := net.ParseCIDR(source) + if err != nil { + return util.NewValidationError(fmt.Sprintf("could not parse data transfer limit source %#v: %v", source, err)) + } + } + if limit.TotalDataTransfer > 0 { + user.Filters.DataTransferLimits[idx].UploadDataTransfer = 0 + user.Filters.DataTransferLimits[idx].DownloadDataTransfer = 0 + } + } + return nil +} + func validateFilters(user *User) error { checkEmptyFiltersStruct(user) if err := validateIPFilters(user); err != nil { return err } - if err := validateBandwidthLimitFilters(user); err != nil { + if err := validateBandwidthLimitsFilter(user); err != nil { + return err + } + if err := validateTransferLimitsFilter(user); err != nil { return err } user.Filters.DeniedLoginMethods = util.RemoveDuplicates(user.Filters.DeniedLoginMethods) @@ -1913,6 +2051,11 @@ func validateBaseParams(user *User) error { if user.UploadBandwidth < 0 { user.UploadBandwidth = 0 } + if user.TotalDataTransfer > 0 { + // if a total data transfer is defined we reset the separate upload and download limits + user.UploadDataTransfer = 0 + user.DownloadDataTransfer = 0 + } return nil } @@ -2814,6 +2957,8 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro userPwd := u.Password userUsedQuotaSize := u.UsedQuotaSize userUsedQuotaFiles := u.UsedQuotaFiles + userUsedDownloadTransfer := u.UsedDownloadDataTransfer + userUsedUploadTransfer := u.UsedUploadDataTransfer userLastQuotaUpdate := u.LastQuotaUpdate userLastLogin := u.LastLogin userCreatedAt := u.CreatedAt @@ -2826,6 +2971,8 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro u.ID = userID u.UsedQuotaSize = userUsedQuotaSize u.UsedQuotaFiles = userUsedQuotaFiles + u.UsedUploadDataTransfer = userUsedUploadTransfer + u.UsedDownloadDataTransfer = userUsedDownloadTransfer u.LastQuotaUpdate = userLastQuotaUpdate u.LastLogin = userLastLogin u.CreatedAt = userCreatedAt @@ -3034,6 +3181,8 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv user.ID = u.ID user.UsedQuotaSize = u.UsedQuotaSize user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer user.LastQuotaUpdate = u.LastQuotaUpdate user.LastLogin = u.LastLogin user.CreatedAt = u.CreatedAt @@ -3100,6 +3249,8 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string, user.ID = u.ID user.UsedQuotaSize = u.UsedQuotaSize user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer user.LastQuotaUpdate = u.LastQuotaUpdate user.LastLogin = u.LastLogin // preserve TOTP config and recovery codes diff --git a/dataprovider/memory.go b/dataprovider/memory.go index 81fb5d1b..cb089880 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -208,6 +208,31 @@ func (p *MemoryProvider) updateAdminLastLogin(username string) error { return nil } +func (p *MemoryProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelError, "unable to update transfer quota for user %#v error: %v", username, err) + return err + } + if reset { + user.UsedUploadDataTransfer = uploadSize + user.UsedDownloadDataTransfer = downloadSize + } else { + user.UsedUploadDataTransfer += uploadSize + user.UsedDownloadDataTransfer += downloadSize + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v", + username, uploadSize, downloadSize, reset) + p.dbHandle.users[user.Username] = user + return nil +} + func (p *MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { p.dbHandle.Lock() defer p.dbHandle.Unlock() @@ -233,18 +258,18 @@ func (p *MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int6 return nil } -func (p *MemoryProvider) getUsedQuota(username string) (int, int64, error) { +func (p *MemoryProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { p.dbHandle.Lock() defer p.dbHandle.Unlock() if p.dbHandle.isClosed { - return 0, 0, errMemoryProviderClosed + return 0, 0, 0, 0, errMemoryProviderClosed } user, err := p.userExistsInternal(username) if err != nil { providerLog(logger.LevelError, "unable to get quota for user %#v error: %v", username, err) - return 0, 0, err + return 0, 0, 0, 0, err } - return user.UsedQuotaFiles, user.UsedQuotaSize, err + return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err } func (p *MemoryProvider) addUser(user *User) error { @@ -269,6 +294,8 @@ func (p *MemoryProvider) addUser(user *User) error { user.LastQuotaUpdate = 0 user.UsedQuotaSize = 0 user.UsedQuotaFiles = 0 + user.UsedUploadDataTransfer = 0 + user.UsedDownloadDataTransfer = 0 user.LastLogin = 0 user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) @@ -304,6 +331,8 @@ func (p *MemoryProvider) updateUser(user *User) error { user.LastQuotaUpdate = u.LastQuotaUpdate user.UsedQuotaSize = u.UsedQuotaSize user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer user.LastLogin = u.LastLogin user.CreatedAt = u.CreatedAt user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) @@ -1335,6 +1364,26 @@ func (p *MemoryProvider) cleanupDefender(from int64) error { return ErrNotImplemented } +func (p *MemoryProvider) addActiveTransfer(transfer ActiveTransfer) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) cleanupActiveTransfers(before time.Time) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return nil, ErrNotImplemented +} + func (p *MemoryProvider) getNextID() int64 { nextID := int64(1) for _, v := range p.dbHandle.users { diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 71432831..bcf54253 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -29,6 +29,7 @@ const ( "DROP TABLE IF EXISTS `{{users}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" + "DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;" mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + @@ -75,6 +76,30 @@ const ( "CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" + "CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" + "INSERT INTO {{schema_version}} (version) VALUES (15);" + mysqlV16SQL = "ALTER TABLE `{{users}}` ADD COLUMN `download_data_transfer` integer DEFAULT 0 NOT NULL;" + + "ALTER TABLE `{{users}}` ALTER COLUMN `download_data_transfer` DROP DEFAULT;" + + "ALTER TABLE `{{users}}` ADD COLUMN `total_data_transfer` integer DEFAULT 0 NOT NULL;" + + "ALTER TABLE `{{users}}` ALTER COLUMN `total_data_transfer` DROP DEFAULT;" + + "ALTER TABLE `{{users}}` ADD COLUMN `upload_data_transfer` integer DEFAULT 0 NOT NULL;" + + "ALTER TABLE `{{users}}` ALTER COLUMN `upload_data_transfer` DROP DEFAULT;" + + "ALTER TABLE `{{users}}` ADD COLUMN `used_download_data_transfer` integer DEFAULT 0 NOT NULL;" + + "ALTER TABLE `{{users}}` ALTER COLUMN `used_download_data_transfer` DROP DEFAULT;" + + "ALTER TABLE `{{users}}` ADD COLUMN `used_upload_data_transfer` integer DEFAULT 0 NOT NULL;" + + "ALTER TABLE `{{users}}` ALTER COLUMN `used_upload_data_transfer` DROP DEFAULT;" + + "CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + + "`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " + + "`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " + + "`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " + + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" + + "CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" + + "CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" + + "CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);" + mysqlV16DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `used_upload_data_transfer`;" + + "ALTER TABLE `{{users}}` DROP COLUMN `used_download_data_transfer`;" + + "ALTER TABLE `{{users}}` DROP COLUMN `upload_data_transfer`;" + + "ALTER TABLE `{{users}}` DROP COLUMN `total_data_transfer`;" + + "ALTER TABLE `{{users}}` DROP COLUMN `download_data_transfer`;" + + "DROP TABLE `{{active_transfers}}` CASCADE;" ) // MySQLProvider defines the auth provider for MySQL/MariaDB database @@ -138,11 +163,15 @@ func (p *MySQLProvider) validateUserAndPubKey(username string, publicKey []byte) return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } +func (p *MySQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) +} + func (p *MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } -func (p *MySQLProvider) getUsedQuota(username string) (int, int64, error) { +func (p *MySQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { return sqlCommonGetUsedQuota(username, p.dbHandle) } @@ -340,6 +369,26 @@ func (p *MySQLProvider) cleanupDefender(from int64) error { return sqlCommonDefenderCleanup(from, p.dbHandle) } +func (p *MySQLProvider) addActiveTransfer(transfer ActiveTransfer) error { + return sqlCommonAddActiveTransfer(transfer, p.dbHandle) +} + +func (p *MySQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) +} + +func (p *MySQLProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) +} + +func (p *MySQLProvider) cleanupActiveTransfers(before time.Time) error { + return sqlCommonCleanupActiveTransfers(before, p.dbHandle) +} + +func (p *MySQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return sqlCommonGetActiveTransfers(from, p.dbHandle) +} + func (p *MySQLProvider) close() error { return p.dbHandle.Close() } @@ -388,6 +437,8 @@ func (p *MySQLProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err + case version == 15: + return updateMySQLDatabaseFromV15(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -410,6 +461,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 16: + return downgradeMySQLDatabaseFromV16(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -425,5 +478,31 @@ func (p *MySQLProvider) resetDatabase() error { sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents) sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0) } + +func updateMySQLDatabaseFromV15(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom15To16(dbHandle) +} + +func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error { + return downgradeMySQLDatabaseFrom16To15(dbHandle) +} + +func updateMySQLDatabaseFrom15To16(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 15 -> 16") + providerLog(logger.LevelInfo, "updating database version: 15 -> 16") + sql := strings.ReplaceAll(mysqlV16SQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16) +} + +func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 16 -> 15") + providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") + sql := strings.ReplaceAll(mysqlV16DownSQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15) +} diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 7eb37197..faf8484e 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -29,6 +29,7 @@ DROP TABLE IF EXISTS "{{shares}}" CASCADE; DROP TABLE IF EXISTS "{{users}}" CASCADE; DROP TABLE IF EXISTS "{{defender_events}}" CASCADE; DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE; +DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE; DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; ` pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL); @@ -86,6 +87,32 @@ CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("b CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); INSERT INTO {{schema_version}} (version) VALUES (15); +` + pgsqlV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ALTER COLUMN "download_data_transfer" DROP DEFAULT; +ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ALTER COLUMN "total_data_transfer" DROP DEFAULT; +ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ALTER COLUMN "upload_data_transfer" DROP DEFAULT; +ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" DROP DEFAULT; +ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" DROP DEFAULT; +CREATE TABLE "{{active_transfers}}" ("id" bigserial NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL, +"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, +"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, +"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL); +CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); +CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); +CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); +` + pgsqlV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer" CASCADE; +ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer" CASCADE; +ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer" CASCADE; +ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer" CASCADE; +ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer" CASCADE; +DROP TABLE "{{active_transfers}}" CASCADE; ` ) @@ -150,11 +177,15 @@ func (p *PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte) return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } +func (p *PGSQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) +} + func (p *PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } -func (p *PGSQLProvider) getUsedQuota(username string) (int, int64, error) { +func (p *PGSQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { return sqlCommonGetUsedQuota(username, p.dbHandle) } @@ -352,6 +383,26 @@ func (p *PGSQLProvider) cleanupDefender(from int64) error { return sqlCommonDefenderCleanup(from, p.dbHandle) } +func (p *PGSQLProvider) addActiveTransfer(transfer ActiveTransfer) error { + return sqlCommonAddActiveTransfer(transfer, p.dbHandle) +} + +func (p *PGSQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) +} + +func (p *PGSQLProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) +} + +func (p *PGSQLProvider) cleanupActiveTransfers(before time.Time) error { + return sqlCommonCleanupActiveTransfers(before, p.dbHandle) +} + +func (p *PGSQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return sqlCommonGetActiveTransfers(from, p.dbHandle) +} + func (p *PGSQLProvider) close() error { return p.dbHandle.Close() } @@ -406,6 +457,8 @@ func (p *PGSQLProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err + case version == 15: + return updatePGSQLDatabaseFromV15(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -428,6 +481,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 16: + return downgradePGSQLDatabaseFromV16(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -443,5 +498,31 @@ func (p *PGSQLProvider) resetDatabase() error { sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents) sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) } + +func updatePGSQLDatabaseFromV15(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom15To16(dbHandle) +} + +func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error { + return downgradePGSQLDatabaseFrom16To15(dbHandle) +} + +func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 15 -> 16") + providerLog(logger.LevelInfo, "updating database version: 15 -> 16") + sql := strings.ReplaceAll(pgsqlV16SQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) +} + +func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 16 -> 15") + providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") + sql := strings.ReplaceAll(pgsqlV16DownSQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15) +} diff --git a/dataprovider/quotaupdater.go b/dataprovider/quotaupdater.go index 20c17815..9190e63d 100644 --- a/dataprovider/quotaupdater.go +++ b/dataprovider/quotaupdater.go @@ -18,18 +18,25 @@ type quotaObject struct { files int } +type transferQuotaObject struct { + ulSize int64 + dlSize int64 +} + type quotaUpdater struct { paramsMutex sync.RWMutex waitTime time.Duration sync.RWMutex - pendingUserQuotaUpdates map[string]quotaObject - pendingFolderQuotaUpdates map[string]quotaObject + pendingUserQuotaUpdates map[string]quotaObject + pendingFolderQuotaUpdates map[string]quotaObject + pendingTransferQuotaUpdates map[string]transferQuotaObject } func newQuotaUpdater() quotaUpdater { return quotaUpdater{ - pendingUserQuotaUpdates: make(map[string]quotaObject), - pendingFolderQuotaUpdates: make(map[string]quotaObject), + pendingUserQuotaUpdates: make(map[string]quotaObject), + pendingFolderQuotaUpdates: make(map[string]quotaObject), + pendingTransferQuotaUpdates: make(map[string]transferQuotaObject), } } @@ -50,6 +57,7 @@ func (q *quotaUpdater) loop() { providerLog(logger.LevelDebug, "delayed quota update check start") q.storeUsersQuota() q.storeFoldersQuota() + q.storeUsersTransferQuota() providerLog(logger.LevelDebug, "delayed quota update check end") waitTime = q.getWaitTime() } @@ -130,6 +138,36 @@ func (q *quotaUpdater) getFolderPendingQuota(name string) (int, int64) { return obj.files, obj.size } +func (q *quotaUpdater) resetUserTransferQuota(username string) { + q.Lock() + defer q.Unlock() + + delete(q.pendingTransferQuotaUpdates, username) +} + +func (q *quotaUpdater) updateUserTransferQuota(username string, ulSize, dlSize int64) { + q.Lock() + defer q.Unlock() + + obj := q.pendingTransferQuotaUpdates[username] + obj.ulSize += ulSize + obj.dlSize += dlSize + if obj.ulSize == 0 && obj.dlSize == 0 { + delete(q.pendingTransferQuotaUpdates, username) + return + } + q.pendingTransferQuotaUpdates[username] = obj +} + +func (q *quotaUpdater) getUserPendingTransferQuota(username string) (int64, int64) { + q.RLock() + defer q.RUnlock() + + obj := q.pendingTransferQuotaUpdates[username] + + return obj.ulSize, obj.dlSize +} + func (q *quotaUpdater) getUsernames() []string { q.RLock() defer q.RUnlock() @@ -154,6 +192,18 @@ func (q *quotaUpdater) getFoldernames() []string { return result } +func (q *quotaUpdater) getTransferQuotaUsernames() []string { + q.RLock() + defer q.RUnlock() + + result := make([]string, 0, len(q.pendingTransferQuotaUpdates)) + for username := range q.pendingTransferQuotaUpdates { + result = append(result, username) + } + + return result +} + func (q *quotaUpdater) storeUsersQuota() { for _, username := range q.getUsernames() { files, size := q.getUserPendingQuota(username) @@ -181,3 +231,17 @@ func (q *quotaUpdater) storeFoldersQuota() { } } } + +func (q *quotaUpdater) storeUsersTransferQuota() { + for _, username := range q.getTransferQuotaUsernames() { + ulSize, dlSize := q.getUserPendingTransferQuota(username) + if ulSize != 0 || dlSize != 0 { + err := provider.updateTransferQuota(username, ulSize, dlSize, false) + if err != nil { + providerLog(logger.LevelWarn, "unable to update transfer quota delayed for user %#v: %v", username, err) + continue + } + q.updateUserTransferQuota(username, -ulSize, -dlSize) + } + } +} diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 7dce91f9..107e4f42 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -18,7 +18,7 @@ import ( ) const ( - sqlDatabaseVersion = 15 + sqlDatabaseVersion = 16 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -639,6 +639,26 @@ func sqlCommonCheckAvailability(dbHandle *sql.DB) error { return dbHandle.PingContext(ctx) } +func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + q := getUpdateTransferQuotaQuery(reset) + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.ExecContext(ctx, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err == nil { + providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v", + username, uploadSize, downloadSize, reset) + } else { + providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err) + } + return err +} + func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() @@ -659,25 +679,25 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo return err } -func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) { +func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getQuotaQuery() stmt, err := dbHandle.PrepareContext(ctx, q) if err != nil { providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) - return 0, 0, err + return 0, 0, 0, 0, err } defer stmt.Close() var usedFiles int - var usedSize int64 - err = stmt.QueryRowContext(ctx, username).Scan(&usedSize, &usedFiles) + var usedSize, usedUploadSize, usedDownloadSize int64 + err = stmt.QueryRowContext(ctx, username).Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize) if err != nil { providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err) - return 0, 0, err + return 0, 0, 0, 0, err } - return usedFiles, usedSize, err + return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err } func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error { @@ -806,10 +826,11 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { if err != nil { return err } - _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, - user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), - string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), - util.GetTimeAsMsSinceEpoch(time.Now())) + _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, + user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, + user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, + user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), + user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer) if err != nil { return err } @@ -849,9 +870,10 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { if err != nil { return err } - _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, - user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, - string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), + _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, + user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, + user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email, + util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.ID) if err != nil { return err @@ -1013,16 +1035,124 @@ func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier for rows.Next() { var user User - err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize) + var filters sql.NullString + err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer, + &user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer, + &user.UsedDownloadDataTransfer, &filters) if err != nil { return users, err } + if filters.Valid { + var userFilters UserFilters + err = json.Unmarshal([]byte(filters.String), &userFilters) + if err == nil { + user.Filters = userFilters + } + } users = append(users, user) } return users, rows.Err() } +func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + q := getAddActiveTransferQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + now := util.GetTimeAsMsSinceEpoch(time.Now()) + _, err = stmt.ExecContext(ctx, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username, + transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize, + now, now) + return err +} + +func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + q := getUpdateActiveTransferSizesQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + + _, err = stmt.ExecContext(ctx, ulSize, dlSize, util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID) + return err +} + +func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + q := getRemoveActiveTransferQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.ExecContext(ctx, connectionID, transferID) + return err +} + +func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getCleanupActiveTransfersQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(before)) + return err +} + +func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) { + transfers := make([]ActiveTransfer, 0, 30) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getActiveTransfersQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return nil, err + } + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx, util.GetTimeAsMsSinceEpoch(from)) + if err != nil { + return nil, err + } + + defer rows.Close() + for rows.Next() { + var transfer ActiveTransfer + var folderName sql.NullString + err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP, + &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt, + &transfer.UpdatedAt) + if err != nil { + return transfers, err + } + if folderName.Valid { + transfer.FolderName = folderName.String + } + transfers = append(transfers, transfer) + } + + return transfers, rows.Err() +} + func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) { users := make([]User, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) @@ -1439,7 +1569,8 @@ func getUserFromDbRow(row sqlScanner) (User, error) { err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate, &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, - &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt) + &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer, + &user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer) if err != nil { if errors.Is(err, sql.ErrNoRows) { return user, util.NewRecordNotFoundError(err.Error()) diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 264849ec..d18861a1 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -11,6 +11,7 @@ import ( "fmt" "path/filepath" "strings" + "time" // we import go-sqlite3 here to be able to disable SQLite support using a build tag _ "github.com/mattn/go-sqlite3" @@ -30,6 +31,7 @@ DROP TABLE IF EXISTS "{{shares}}"; DROP TABLE IF EXISTS "{{users}}"; DROP TABLE IF EXISTS "{{defender_events}}"; DROP TABLE IF EXISTS "{{defender_hosts}}"; +DROP TABLE IF EXISTS "{{active_transfers}}"; DROP TABLE IF EXISTS "{{schema_version}}"; ` sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL); @@ -78,6 +80,27 @@ CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("b CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); INSERT INTO {{schema_version}} (version) VALUES (15); +` + sqliteV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL; +CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "connection_id" varchar(100) NOT NULL, +"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, +"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, +"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL); +CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); +CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); +CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); +` + sqliteV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer"; +ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer"; +ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer"; +ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer"; +ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer"; +DROP TABLE "{{active_transfers}}"; ` ) @@ -134,11 +157,15 @@ func (p *SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } +func (p *SQLiteProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) +} + func (p *SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } -func (p *SQLiteProvider) getUsedQuota(username string) (int, int64, error) { +func (p *SQLiteProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { return sqlCommonGetUsedQuota(username, p.dbHandle) } @@ -337,6 +364,26 @@ func (p *SQLiteProvider) cleanupDefender(from int64) error { return sqlCommonDefenderCleanup(from, p.dbHandle) } +func (p *SQLiteProvider) addActiveTransfer(transfer ActiveTransfer) error { + return sqlCommonAddActiveTransfer(transfer, p.dbHandle) +} + +func (p *SQLiteProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) +} + +func (p *SQLiteProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) +} + +func (p *SQLiteProvider) cleanupActiveTransfers(before time.Time) error { + return sqlCommonCleanupActiveTransfers(before, p.dbHandle) +} + +func (p *SQLiteProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return sqlCommonGetActiveTransfers(from, p.dbHandle) +} + func (p *SQLiteProvider) close() error { return p.dbHandle.Close() } @@ -385,6 +432,8 @@ func (p *SQLiteProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err + case version == 15: + return updateSQLiteDatabaseFromV15(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -407,6 +456,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 16: + return downgradeSQLiteDatabaseFromV16(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -422,9 +473,35 @@ func (p *SQLiteProvider) resetDatabase() error { sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents) sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) } +func updateSQLiteDatabaseFromV15(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom15To16(dbHandle) +} + +func downgradeSQLiteDatabaseFromV16(dbHandle *sql.DB) error { + return downgradeSQLiteDatabaseFrom16To15(dbHandle) +} + +func updateSQLiteDatabaseFrom15To16(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 15 -> 16") + providerLog(logger.LevelInfo, "updating database version: 15 -> 16") + sql := strings.ReplaceAll(sqliteV16SQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) +} + +func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 16 -> 15") + providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") + sql := strings.ReplaceAll(sqliteV16DownSQL, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15) +} + /*func setPragmaFK(dbHandle *sql.DB, value string) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 35ef7265..1342bba1 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -11,7 +11,8 @@ import ( const ( selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," + "used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," + - "additional_info,description,email,created_at,updated_at" + "additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer," + + "used_upload_data_transfer,used_download_data_transfer" selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem" selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login" selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id" @@ -276,7 +277,8 @@ func getUsersForQuotaCheckQuery(numArgs int) string { if sb.Len() > 0 { sb.WriteString(")") } - return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`, + return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size,total_data_transfer,upload_data_transfer, + download_data_transfer,used_upload_data_transfer,used_download_data_transfer,filters FROM %v WHERE username IN %v`, sqlTableUsers, sb.String()) } @@ -292,6 +294,16 @@ func getDumpFoldersQuery() string { return fmt.Sprintf(`SELECT %v FROM %v`, selectFolderFields, sqlTableFolders) } +func getUpdateTransferQuotaQuery(reset bool) string { + if reset { + return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = %v,used_download_data_transfer = %v,last_quota_update = %v + WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = used_upload_data_transfer + %v, + used_download_data_transfer = used_download_data_transfer + %v,last_quota_update = %v + WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + func getUpdateQuotaQuery(reset bool) string { if reset { return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v @@ -323,28 +335,34 @@ func getUpdateShareLastUseQuery() string { } func getQuotaQuery() string { - return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE username = %v`, sqlTableUsers, - sqlPlaceholders[0]) + return fmt.Sprintf(`SELECT used_quota_size,used_quota_files,used_upload_data_transfer, + used_download_data_transfer FROM %v WHERE username = %v`, + sqlTableUsers, sqlPlaceholders[0]) } func getAddUserQuery() string { return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters, - filesystem,additional_info,description,email,created_at,updated_at) - VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v,%v)`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], - sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], - sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], - sqlPlaceholders[14], sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], - sqlPlaceholders[20]) + filesystem,additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer, + used_upload_data_transfer,used_download_data_transfer) + VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0)`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], + sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], + sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], + sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], + sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22], sqlPlaceholders[23]) } func getUpdateUserQuery() string { return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v, quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v,status=%v,expiration_date=%v,filters=%v,filesystem=%v, - additional_info=%v,description=%v,email=%v,updated_at=%v WHERE id = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], - sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], - sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], sqlPlaceholders[15], - sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19]) + additional_info=%v,description=%v,email=%v,updated_at=%v,upload_data_transfer=%v,download_data_transfer=%v, + total_data_transfer=%v WHERE id = %v`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], + sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], + sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], + sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], + sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22]) } func getDeleteUserQuery() string { @@ -439,6 +457,34 @@ func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string { WHERE fm.folder_id IN %v ORDER BY fm.folder_id`, sqlTableFoldersMapping, sqlTableUsers, sb.String()) } +func getActiveTransfersQuery() string { + return fmt.Sprintf(`SELECT transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, + current_ul_size,current_dl_size,created_at,updated_at FROM %v WHERE updated_at > %v`, + sqlTableActiveTransfers, sqlPlaceholders[0]) +} + +func getAddActiveTransferQuery() string { + return fmt.Sprintf(`INSERT INTO %v (transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, + current_ul_size,current_dl_size,created_at,updated_at) VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], + sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], + sqlPlaceholders[9], sqlPlaceholders[10]) +} + +func getUpdateActiveTransferSizesQuery() string { + return fmt.Sprintf(`UPDATE %v SET current_ul_size=%v,current_dl_size=%v,updated_at=%v WHERE connection_id = %v AND transfer_id = %v`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4]) +} + +func getRemoveActiveTransferQuery() string { + return fmt.Sprintf(`DELETE FROM %v WHERE connection_id = %v AND transfer_id = %v`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getCleanupActiveTransfersQuery() string { + return fmt.Sprintf(`DELETE FROM %v WHERE updated_at < %v`, sqlTableActiveTransfers, sqlPlaceholders[0]) +} + func getDatabaseVersionQuery() string { return fmt.Sprintf("SELECT version from %v LIMIT 1", sqlTableSchemaVersion) } diff --git a/dataprovider/user.go b/dataprovider/user.go index de130246..54ff1623 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -462,7 +462,7 @@ func (u *User) GetFilesystemForPath(virtualPath, connectionID string) (vfs.Fs, e // If the path is not inside a virtual folder an error is returned func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, error) { var folder vfs.VirtualFolder - if len(u.VirtualFolders) == 0 { + if virtualPath == "/" || len(u.VirtualFolders) == 0 { return folder, errNoMatchingVirtualFolder } dirsForPath := util.GetDirsForVirtualPath(virtualPath) @@ -1071,11 +1071,58 @@ func (u *User) GetHomeDir() string { return filepath.Clean(u.HomeDir) } -// HasQuotaRestrictions returns true if there is a quota restriction on number of files or size or both +// HasQuotaRestrictions returns true if there are any disk quota restrictions func (u *User) HasQuotaRestrictions() bool { return u.QuotaFiles > 0 || u.QuotaSize > 0 } +// HasTransferQuotaRestrictions returns true if there are any data transfer restrictions +func (u *User) HasTransferQuotaRestrictions() bool { + if len(u.Filters.DataTransferLimits) > 0 { + return true + } + return u.UploadDataTransfer > 0 || u.TotalDataTransfer > 0 || u.DownloadDataTransfer > 0 +} + +// GetDataTransferLimits returns upload, download and total data transfer limits +func (u *User) GetDataTransferLimits(clientIP string) (int64, int64, int64) { + var total, ul, dl int64 + if len(u.Filters.DataTransferLimits) > 0 { + ip := net.ParseIP(clientIP) + if ip != nil { + for _, limit := range u.Filters.DataTransferLimits { + for _, source := range limit.Sources { + _, ipNet, err := net.ParseCIDR(source) + if err == nil { + if ipNet.Contains(ip) { + if limit.TotalDataTransfer > 0 { + total = limit.TotalDataTransfer * 1048576 + } + if limit.DownloadDataTransfer > 0 { + dl = limit.DownloadDataTransfer * 1048576 + } + if limit.UploadDataTransfer > 0 { + ul = limit.UploadDataTransfer * 1048576 + } + return ul, dl, total + } + } + } + } + } + } + if u.TotalDataTransfer > 0 { + total = u.TotalDataTransfer * 1048576 + } + if u.DownloadDataTransfer > 0 { + dl = u.DownloadDataTransfer * 1048576 + } + if u.UploadDataTransfer > 0 { + ul = u.UploadDataTransfer * 1048576 + } + return ul, dl, total +} + // GetQuotaSummary returns used quota and limits if defined func (u *User) GetQuotaSummary() string { var result string @@ -1283,33 +1330,50 @@ func (u *User) getACopy() User { copy(bwLimit.Sources, limit.Sources) filters.BandwidthLimits = append(filters.BandwidthLimits, bwLimit) } + filters.DataTransferLimits = make([]sdk.DataTransferLimit, 0, len(u.Filters.DataTransferLimits)) + for _, limit := range u.Filters.DataTransferLimits { + dtLimit := sdk.DataTransferLimit{ + UploadDataTransfer: limit.UploadDataTransfer, + DownloadDataTransfer: limit.DownloadDataTransfer, + TotalDataTransfer: limit.TotalDataTransfer, + Sources: make([]string, 0, len(limit.Sources)), + } + dtLimit.Sources = make([]string, len(limit.Sources)) + copy(dtLimit.Sources, limit.Sources) + filters.DataTransferLimits = append(filters.DataTransferLimits, dtLimit) + } return User{ BaseUser: sdk.BaseUser{ - ID: u.ID, - Username: u.Username, - Email: u.Email, - Password: u.Password, - PublicKeys: pubKeys, - HomeDir: u.HomeDir, - UID: u.UID, - GID: u.GID, - MaxSessions: u.MaxSessions, - QuotaSize: u.QuotaSize, - QuotaFiles: u.QuotaFiles, - Permissions: permissions, - UsedQuotaSize: u.UsedQuotaSize, - UsedQuotaFiles: u.UsedQuotaFiles, - LastQuotaUpdate: u.LastQuotaUpdate, - UploadBandwidth: u.UploadBandwidth, - DownloadBandwidth: u.DownloadBandwidth, - Status: u.Status, - ExpirationDate: u.ExpirationDate, - LastLogin: u.LastLogin, - AdditionalInfo: u.AdditionalInfo, - Description: u.Description, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Username: u.Username, + Email: u.Email, + Password: u.Password, + PublicKeys: pubKeys, + HomeDir: u.HomeDir, + UID: u.UID, + GID: u.GID, + MaxSessions: u.MaxSessions, + QuotaSize: u.QuotaSize, + QuotaFiles: u.QuotaFiles, + Permissions: permissions, + UsedQuotaSize: u.UsedQuotaSize, + UsedQuotaFiles: u.UsedQuotaFiles, + LastQuotaUpdate: u.LastQuotaUpdate, + UploadBandwidth: u.UploadBandwidth, + DownloadBandwidth: u.DownloadBandwidth, + UploadDataTransfer: u.UploadDataTransfer, + DownloadDataTransfer: u.DownloadDataTransfer, + TotalDataTransfer: u.TotalDataTransfer, + UsedUploadDataTransfer: u.UsedUploadDataTransfer, + UsedDownloadDataTransfer: u.UsedDownloadDataTransfer, + Status: u.Status, + ExpirationDate: u.ExpirationDate, + LastLogin: u.LastLogin, + AdditionalInfo: u.AdditionalInfo, + Description: u.Description, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, }, Filters: filters, VirtualFolders: virtualFolders, diff --git a/docs/full-configuration.md b/docs/full-configuration.md index d2931845..1d91d2bc 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -214,7 +214,7 @@ The configuration file contains the following sections: - `update_mode`, integer. Defines how the database will be initialized/updated. 0 means automatically. 1 means manually using the initprovider sub-command. - `skip_natural_keys_validation`, boolean. If `true` you can use any UTF-8 character for natural keys as username, admin name, folder name. These keys are used in URIs for REST API and Web admin. If `false` only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". Default: `false`. - `create_default_admin`, boolean. Before you can use SFTPGo you need to create an admin account. If you open the admin web UI, a setup screen will guide you in creating the first admin account. You can automatically create the first admin account by enabling this setting and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD`. You can also create the first admin by loading initial data. This setting has no effect if an admin account is already found within the data provider. Default `false`. - - `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, SFTPGo periodically reloads the latest updated users, based on the `updated_at` field, and updates its internal caches if users are updated from a different instance. This check, if enabled, is executed every 10 minutes. Default: `0`. + - `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, SFTPGo periodically reloads the latest updated users, based on the `updated_at` field, and updates its internal caches if users are updated from a different instance. This check, if enabled, is executed every 10 minutes. For shared data providers, active transfers are persisted in the database and thus quota checks between ongoing transfers will work cross multiple instances. Default: `0`. - **"httpd"**, the configuration for the HTTP server used to serve REST API and to expose the built-in web interface - `bindings`, list of structs. Each struct has the following fields: - `port`, integer. The port used for serving HTTP requests. Default: 8080. diff --git a/docs/howto/getting-started.md b/docs/howto/getting-started.md index 855186b7..ce444bc8 100644 --- a/docs/howto/getting-started.md +++ b/docs/howto/getting-started.md @@ -342,7 +342,7 @@ Restart SFTPGo to apply the changes. ### Use CockroachDB data provider -We suppose you have installed CocroackDB this way: +We suppose you have installed CockroachDB this way: ```shell sudo su diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 9bd4adea..bda37e24 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -289,7 +289,7 @@ func TestMain(m *testing.M) { os.Exit(1) } - err = common.Initialize(commonConf) + err = common.Initialize(commonConf, 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) @@ -1042,7 +1042,7 @@ func TestRateLimiter(t *testing.T) { }, } - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) @@ -1076,7 +1076,7 @@ func TestRateLimiter(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -1088,7 +1088,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) @@ -1118,7 +1118,7 @@ func TestDefender(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -1998,6 +1998,71 @@ func TestUploadOverwriteVfolder(t *testing.T) { assert.NoError(t, err) } +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(524288) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ftpserver.ErrStorageExceeded.Error()) + } + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = client.Quit() + assert.NoError(t, err) + } + + testFileSize = int64(600000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + user.DownloadDataTransfer = 2 + user.UploadDataTransfer = 2 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.Error(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.Error(t, err) + + err = client.Quit() + assert.NoError(t, err) + } + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestAllocateAvailable(t *testing.T) { u := getTestUser() mappedPath := filepath.Join(os.TempDir(), "vdir") @@ -2042,18 +2107,10 @@ func TestAllocateAvailable(t *testing.T) { testFileSize := user.QuotaSize - 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - code, response, err := client.SendCustomCommand("allo 99") + code, response, err := client.SendCustomCommand("allo 1000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) - code, response, err = client.SendCustomCommand("allo 100") - assert.NoError(t, err) - assert.Equal(t, ftp.StatusCommandOK, code) - assert.Equal(t, "Done !", response) - code, response, err = client.SendCustomCommand("allo 150") - assert.NoError(t, err) - assert.Equal(t, ftp.StatusFileUnavailable, code) - assert.Contains(t, response, ftpserver.ErrStorageExceeded.Error()) err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) assert.NoError(t, err) @@ -2063,38 +2120,69 @@ func TestAllocateAvailable(t *testing.T) { assert.Equal(t, ftp.StatusFile, code) assert.Equal(t, "1", response) - // we still have space in vdir - code, response, err = client.SendCustomCommand("allo 50") - assert.NoError(t, err) - assert.Equal(t, ftp.StatusCommandOK, code) - assert.Equal(t, "Done !", response) - err = ftpUploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client, 0) - assert.NoError(t, err) - code, response, err = client.SendCustomCommand("allo 50") - assert.NoError(t, err) - assert.Equal(t, ftp.StatusFileUnavailable, code) - assert.Contains(t, response, ftpserver.ErrStorageExceeded.Error()) - err = client.Quit() assert.NoError(t, err) err = os.Remove(testFilePath) assert.NoError(t, err) } + user.TotalDataTransfer = 1 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCustomCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "1", response) - user.Filters.MaxUploadFileSize = 100 + err = client.Quit() + assert.NoError(t, err) + } + + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 5 + user.QuotaSize = 6 * 1024 * 1024 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCustomCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "5242880", response) + + err = client.Quit() + assert.NoError(t, err) + } + + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 5 user.QuotaSize = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client, err = getFTPClient(user, false, nil) if assert.NoError(t, err) { - code, response, err := client.SendCustomCommand("allo 99") + code, response, err := client.SendCustomCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "5242880", response) + + err = client.Quit() + assert.NoError(t, err) + } + + user.Filters.MaxUploadFileSize = 100 + user.QuotaSize = 0 + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCustomCommand("allo 10000") assert.NoError(t, err) assert.Equal(t, ftp.StatusCommandOK, code) assert.Equal(t, "Done !", response) - code, response, err = client.SendCustomCommand("allo 150") - assert.NoError(t, err) - assert.Equal(t, ftp.StatusFileUnavailable, code) - assert.Contains(t, response, ftpserver.ErrStorageExceeded.Error()) code, response, err = client.SendCustomCommand("AVBL") assert.NoError(t, err) diff --git a/ftpd/handler.go b/ftpd/handler.go index da58808c..dace0cac 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -202,12 +202,12 @@ func (c *Connection) Chtimes(name string, atime time.Time, mtime time.Time) erro func (c *Connection) GetAvailableSpace(dirName string) (int64, error) { c.UpdateLastActivity() - quotaResult := c.HasSpace(false, false, path.Join(dirName, "fakefile.txt")) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(false, false, path.Join(dirName, "fakefile.txt")) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { return 0, nil } - if quotaResult.AllowedSize == 0 { + if diskQuota.AllowedSize == 0 && transferQuota.AllowedULSize == 0 && transferQuota.AllowedTotalSize == 0 { // no quota restrictions if c.User.Filters.MaxUploadFileSize > 0 { return c.User.Filters.MaxUploadFileSize, nil @@ -225,45 +225,35 @@ func (c *Connection) GetAvailableSpace(dirName string) (int64, error) { return int64(statVFS.FreeSpace()), nil } + allowedDiskSize := diskQuota.AllowedSize + allowedUploadSize := transferQuota.AllowedULSize + if transferQuota.AllowedTotalSize > 0 { + allowedUploadSize = transferQuota.AllowedTotalSize + } + allowedSize := allowedDiskSize + if allowedSize == 0 { + allowedSize = allowedUploadSize + } else { + if allowedUploadSize > 0 && allowedUploadSize < allowedSize { + allowedSize = allowedUploadSize + } + } // the available space is the minimum between MaxUploadFileSize, if setted, // and quota allowed size if c.User.Filters.MaxUploadFileSize > 0 { - if c.User.Filters.MaxUploadFileSize < quotaResult.AllowedSize { + if c.User.Filters.MaxUploadFileSize < allowedSize { return c.User.Filters.MaxUploadFileSize, nil } } - return quotaResult.AllowedSize, nil + return allowedSize, nil } // AllocateSpace implements ClientDriverExtensionAllocate interface func (c *Connection) AllocateSpace(size int) error { c.UpdateLastActivity() - // check the max allowed file size first - if c.User.Filters.MaxUploadFileSize > 0 && int64(size) > c.User.Filters.MaxUploadFileSize { - return c.GetQuotaExceededError() - } - - // we don't have a path here so we check home dir and any virtual folders - // we return no error if there is space in any folder - folders := []string{"/"} - for _, v := range c.User.VirtualFolders { - // the space is checked for the parent folder - folders = append(folders, path.Join(v.VirtualPath, "fakefile.txt")) - } - for _, f := range folders { - quotaResult := c.HasSpace(false, false, f) - if quotaResult.HasSpace { - if quotaResult.QuotaSize == 0 { - // unlimited size is allowed - return nil - } - if quotaResult.GetRemainingSize() > int64(size) { - return nil - } - } - } - return c.GetQuotaExceededError() + // we treat ALLO as NOOP see RFC 959 + return nil } // RemoveDir implements ClientDriverExtensionRemoveDir @@ -318,6 +308,11 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6 if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(ftpPath)) { return nil, c.GetPermissionDeniedError() } + transferQuota := c.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.Log(logger.LevelInfo, "denying file read due to quota limits") + return nil, c.GetReadQuotaExceededError() + } if ok, policy := c.User.IsFileAllowed(ftpPath); !ok { c.Log(logger.LevelWarn, "reading file %#v is not allowed", ftpPath) @@ -336,7 +331,7 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6 } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, nil, r, offset) @@ -381,8 +376,8 @@ func (c *Connection) uploadFile(fs vfs.Fs, fsPath, ftpPath string, flags int) (f } func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) { - quotaResult := c.HasSpace(true, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(true, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, ftpserver.ErrStorageExceeded } @@ -399,10 +394,10 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, w, nil, 0) @@ -412,8 +407,8 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolvedPath, filePath string, fileSize int64, requestPath string) (ftpserver.FileTransfer, error) { var err error - quotaResult := c.HasSpace(false, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(false, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, ftpserver.ErrStorageExceeded } @@ -426,7 +421,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve isResume := flags&os.O_TRUNC == 0 // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before - maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported()) + maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, fs.IsUploadResumeSupported()) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size: %v", err) return nil, err @@ -481,7 +476,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, w, nil, 0) diff --git a/ftpd/internal_test.go b/ftpd/internal_test.go index 7febea90..52923b2f 100644 --- a/ftpd/internal_test.go +++ b/ftpd/internal_test.go @@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) { clientContext: mockCC, } baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) tr := newTransfer(baseTransfer, nil, nil, 0) err = tr.Close() assert.NoError(t, err) @@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, - common.TransferUpload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) tr = newTransfer(baseTransfer, nil, r, 10) pos, err := tr.Seek(10, 0) assert.NoError(t, err) @@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) { assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, - common.TransferUpload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) tr = newTransfer(baseTransfer, pipeWriter, nil, 0) err = r.Close() diff --git a/ftpd/transfer.go b/ftpd/transfer.go index 1a9e5aa9..5cd2e1d2 100644 --- a/ftpd/transfer.go +++ b/ftpd/transfer.go @@ -49,6 +49,9 @@ func (t *transfer) Read(p []byte) (n int, err error) { n, err = t.reader.Read(p) atomic.AddInt64(&t.BytesSent, int64(n)) + if err == nil { + err = t.CheckRead() + } if err != nil && err != io.EOF { t.TransferError(err) return @@ -64,8 +67,8 @@ func (t *transfer) Write(p []byte) (n int, err error) { n, err = t.writer.Write(p) atomic.AddInt64(&t.BytesReceived, int64(n)) - if t.MaxWriteSize > 0 && err == nil && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { - err = t.Connection.GetQuotaExceededError() + if err == nil { + err = t.CheckWrite() } if err != nil { t.TransferError(err) diff --git a/go.mod b/go.mod index fdc44374..7ab7e568 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/drakkan/sftpgo/v2 go 1.17 require ( - cloud.google.com/go/storage v1.18.2 + cloud.google.com/go/storage v1.19.0 github.com/Azure/azure-storage-blob-go v0.14.0 github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387 - github.com/aws/aws-sdk-go v1.42.37 + github.com/aws/aws-sdk-go v1.42.44 github.com/cockroachdb/cockroach-go/v2 v2.2.6 github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b github.com/fclairamb/ftpserverlib v0.17.0 @@ -24,39 +24,39 @@ require ( github.com/hashicorp/go-plugin v1.4.3 github.com/hashicorp/go-retryablehttp v0.7.0 github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126 - github.com/klauspost/compress v1.14.1 - github.com/lestrrat-go/jwx v1.2.17 + github.com/klauspost/compress v1.14.2 + github.com/lestrrat-go/jwx v1.2.18 github.com/lib/pq v1.10.4 github.com/lithammer/shortuuid/v3 v3.0.7 - github.com/mattn/go-sqlite3 v1.14.10 + github.com/mattn/go-sqlite3 v1.14.11 github.com/mhale/smtpd v0.8.0 github.com/minio/sio v0.3.0 github.com/otiai10/copy v1.7.0 github.com/pires/go-proxyproto v0.6.1 - github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae + github.com/pkg/sftp v1.13.5-0.20220119192800-7d25d533c9a3 github.com/pquerna/otp v1.3.0 - github.com/prometheus/client_golang v1.12.0 + github.com/prometheus/client_golang v1.12.1 github.com/rs/cors v1.8.2 github.com/rs/xid v1.3.0 github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50 - github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea + github.com/sftpgo/sdk v0.0.0-20220130093602-2e82a333cdec github.com/shirou/gopsutil/v3 v3.21.13-0.20220106132423-a3ae4bc40d26 github.com/spf13/afero v1.8.0 github.com/spf13/cobra v1.3.0 github.com/spf13/viper v1.10.1 github.com/stretchr/testify v1.7.0 - github.com/studio-b12/gowebdav v0.0.0-20211106090535-29e74efa701f + github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 github.com/wagslane/go-password-validator v0.3.0 github.com/xhit/go-simple-mail/v2 v2.10.0 github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a go.etcd.io/bbolt v1.3.6 go.uber.org/automaxprocs v1.4.0 gocloud.dev v0.24.0 - golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 - golang.org/x/net v0.0.0-20220111093109-d55c255bac03 - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 + golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed + golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 - google.golang.org/api v0.65.0 + google.golang.org/api v0.66.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -76,7 +76,7 @@ require ( github.com/fatih/color v1.13.0 // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect github.com/go-ole/go-ole v1.2.6 // indirect - github.com/goccy/go-json v0.9.3 // indirect + github.com/goccy/go-json v0.9.4 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.7 // indirect @@ -96,7 +96,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.5 // indirect github.com/mattn/go-colorable v0.1.12 // indirect - github.com/mattn/go-ieproxy v0.0.1 // indirect + github.com/mattn/go-ieproxy v0.0.3-0.20220115171849-ffa2c199638b // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/miekg/dns v1.1.45 // indirect @@ -123,11 +123,11 @@ require ( golang.org/x/mod v0.5.1 // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.8 // indirect + golang.org/x/tools v0.1.9 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect - google.golang.org/grpc v1.43.0 // indirect + google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350 // indirect + google.golang.org/grpc v1.44.0 // indirect google.golang.org/protobuf v1.27.1 // indirect gopkg.in/ini.v1 v1.66.3 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect @@ -138,6 +138,6 @@ replace ( github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639 github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20220113173527-7442aa777ac0 github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 - golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220106154735-630a1d952834 - golang.org/x/net => github.com/drakkan/net v0.0.0-20220113164424-6c7f3de7b303 + golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220130095207-a206cf284b7c + golang.org/x/net => github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34 ) diff --git a/go.sum b/go.sum index be2ed194..3525383c 100644 --- a/go.sum +++ b/go.sum @@ -70,8 +70,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= cloud.google.com/go/storage v1.16.1/go.mod h1:LaNorbty3ehnU3rEjXSNV/NRgQA0O8Y+uh6bPe5UOk4= -cloud.google.com/go/storage v1.18.2 h1:5NQw6tOn3eMm0oE8vTkfjau18kjL79FlMjy/CHTpmoY= -cloud.google.com/go/storage v1.18.2/go.mod h1:AiIj7BWXyhO5gGVmYJ+S8tbkCx3yb0IMjua8Aw4naVM= +cloud.google.com/go/storage v1.19.0 h1:XOQSnPJD8hRtZJ3VdCyK0mBZsGGImrzPAMbSWcHSe6Q= +cloud.google.com/go/storage v1.19.0/go.mod h1:6rgiTRjOqI/Zd9YKimub5TIB4d+p3LH33V3ZE1DMuUM= cloud.google.com/go/trace v0.1.0/go.mod h1:wxEwsoeRVPbeSkt7ZC9nWCgmoKQRAoySN7XHW2AmI7g= contrib.go.opencensus.io/exporter/aws v0.0.0-20200617204711-c478e41e60e9/go.mod h1:uu1P0UCM/6RbsMrgPa98ll8ZcHM858i/AD06a9aLRCA= contrib.go.opencensus.io/exporter/stackdriver v0.13.8/go.mod h1:huNtlWx75MwO7qMs0KrMxPZXzNNWebav1Sq/pm02JdQ= @@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= -github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM= -github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= +github.com/aws/aws-sdk-go v1.42.44 h1:vPlF4cUsdN5ETfvb7ewZFbFZyB6Rsfndt3kS2XqLXKo= +github.com/aws/aws-sdk-go v1.42.44/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY= github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY= @@ -214,14 +214,14 @@ github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mz github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8= github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= -github.com/drakkan/crypto v0.0.0-20220106154735-630a1d952834 h1:uyeD3WaQSuxc7/d061EBmjVQj59kMMFqe21U8/IEP7A= -github.com/drakkan/crypto v0.0.0-20220106154735-630a1d952834/go.mod h1:SiM6ypd8Xu1xldObYtbDztuUU7xUzMnUULfphXFZmro= +github.com/drakkan/crypto v0.0.0-20220130095207-a206cf284b7c h1:IqTZK/MGRdMPRyyJQSxDtrEokSJDJl+nreM2/CFYTsg= +github.com/drakkan/crypto v0.0.0-20220130095207-a206cf284b7c/go.mod h1:SiM6ypd8Xu1xldObYtbDztuUU7xUzMnUULfphXFZmro= github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA= github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU= github.com/drakkan/ftpserverlib v0.0.0-20220113173527-7442aa777ac0 h1:8lhuOHaxuxiVuTiS8NHCXZKZ28WWxDzwwwIn673c6Jg= github.com/drakkan/ftpserverlib v0.0.0-20220113173527-7442aa777ac0/go.mod h1:erV/bp9DEm6wvpPewC02KUJz0gdReWyz/7nHZP+4pAI= -github.com/drakkan/net v0.0.0-20220113164424-6c7f3de7b303 h1:0aNnMI/95JKY6sqHbZLVGGvHOP7l6ZI7HU7ejAYm7pM= -github.com/drakkan/net v0.0.0-20220113164424-6c7f3de7b303/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34 h1:DRayAKtBRaVU3jg58b/HCbkRleByBD5q6NkN1wcJ2RU= +github.com/drakkan/net v0.0.0-20220130095023-bd85f1236c34/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639 h1:8tfGdb4kg/YCvAbIrsMazgoNtnqdOqQVDKW12uUCuuU= github.com/drakkan/pipeat v0.0.0-20210805162858-70e57fa8a639/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -283,9 +283,8 @@ github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22 github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/goccy/go-json v0.7.6/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/goccy/go-json v0.9.1/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/goccy/go-json v0.9.3 h1:VYKeLtdIQXWaeTZy5JNGZbVui5ck7Vf5MlWEcflqz0s= -github.com/goccy/go-json v0.9.3/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.9.4 h1:L8MLKG2mvVXiQu07qB6hmfqeSYQdOnqPot2GhsIwIaI= +github.com/goccy/go-json v0.9.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= @@ -508,8 +507,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.5/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.14.1 h1:hLQYb23E8/fO+1u53d02A97a8UnsddcvYzq4ERRU4ds= -github.com/klauspost/compress v1.14.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= +github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -539,8 +538,8 @@ github.com/lestrrat-go/httpcc v1.0.0/go.mod h1:tGS/u00Vh5N6FHNkExqGGNId8e0Big+++ github.com/lestrrat-go/iter v1.0.1 h1:q8faalr2dY6o8bV45uwrxq12bRa1ezKrB6oM9FUgN4A= github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU= -github.com/lestrrat-go/jwx v1.2.17 h1:e6IWTrTu4pI7B8wa9TfAY17Ra9o5ymZ95L5tAjWtfF8= -github.com/lestrrat-go/jwx v1.2.17/go.mod h1:UxIzTZAhlHvgx83iJpnm24r5luD7zlFrtHVbG7Qs9DU= +github.com/lestrrat-go/jwx v1.2.18 h1:RV4hcTRUlPVYUnGqATKXEojoOsLexoU8Na4KheVzxQ8= +github.com/lestrrat-go/jwx v1.2.18/go.mod h1:bWTBO7IHHVMtNunM8so9MT8wD+euEY1PzGEyCnuI2qM= github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4= github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -566,8 +565,9 @@ github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-ieproxy v0.0.1 h1:qiyop7gCflfhwCzGyeT0gro3sF9AIg9HU98JORTkqfI= github.com/mattn/go-ieproxy v0.0.1/go.mod h1:pYabZ6IHcRpFh7vIaLfK7rdcWgFEb3SFJ6/gNWuh88E= +github.com/mattn/go-ieproxy v0.0.3-0.20220115171849-ffa2c199638b h1:hOk7BgJT/9Vt2aIrfXp0qA6hwY2JZSwX4Rmsgp8DJ6E= +github.com/mattn/go-ieproxy v0.0.3-0.20220115171849-ffa2c199638b/go.mod h1:6ZpRmhBaYuBX1U2za+9rC9iCGLsSp2tftelZne7CPko= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= @@ -579,8 +579,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/mattn/go-sqlite3 v1.14.10 h1:MLn+5bFRlWMGoSRmJour3CL1w/qL96mvipqpwQW/Sfk= -github.com/mattn/go-sqlite3 v1.14.10/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.11 h1:gt+cp9c0XGqe9S/wAHTL3n/7MqY+siPWgWJgqdsFrzQ= +github.com/mattn/go-sqlite3 v1.14.11/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mhale/smtpd v0.8.0 h1:5JvdsehCg33PQrZBvFyDMMUDQmvbzVpZgKob7eYBJc0= @@ -636,8 +636,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= -github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae h1:J8MHmz3LSjRtoR4SKiPq8BNo3DacJl5kQRjJeWilkUI= -github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg= +github.com/pkg/sftp v1.13.5-0.20220119192800-7d25d533c9a3 h1:gyvzmVdk4vso+w4gt8x2YtMdbAGSyX5KnekiEsbDLvQ= +github.com/pkg/sftp v1.13.5-0.20220119192800-7d25d533c9a3/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= @@ -651,8 +651,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= -github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= +github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -691,8 +691,8 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY= -github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k= -github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q= +github.com/sftpgo/sdk v0.0.0-20220130093602-2e82a333cdec h1:zdL+7nNYny5e87IDZMFReFHviKRenxmCGDwgLwHIrwU= +github.com/sftpgo/sdk v0.0.0-20220130093602-2e82a333cdec/go.mod h1:gcYbk4z578GfwbC9kJOz2rltYoPYUIcGZgV13r74MJw= github.com/shirou/gopsutil/v3 v3.21.13-0.20220106132423-a3ae4bc40d26 h1:nkvraEu1xs6D3AimiR9SkIOCG6lVvVZRfwbbQ7fX1DY= github.com/shirou/gopsutil/v3 v3.21.13-0.20220106132423-a3ae4bc40d26/go.mod h1:BToYZVTlSVlfazpDDYFnsVZLaoRG+g8ufT6fPQLdJzA= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= @@ -728,8 +728,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/studio-b12/gowebdav v0.0.0-20211106090535-29e74efa701f h1:SLJx0nHhb2ZLlYNMAbrYsjwmVwXx4yRT48lNIxOp7ts= -github.com/studio-b12/gowebdav v0.0.0-20211106090535-29e74efa701f/go.mod h1:gCcfDlA1Y7GqOaeEKw5l9dOGx1VLdc/HuQSlQAaZ30s= +github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 h1:b2nJXyPCa9HY7giGM+kYcnQ71m14JnGdQabMPmyt++8= +github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tklauser/go-sysconf v0.3.9 h1:JeUVdAOWhhxVcU6Eqr/ATFHgXk/mmiItdKeJPev3vTo= @@ -919,7 +919,6 @@ golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210503080704-8803ae5d1324/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -936,7 +935,6 @@ golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210917161153-d61c044b1678/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -945,11 +943,14 @@ golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 h1:XDXtA5hveEEV8JB2l7nhMTp3t3cHp9ZpwcdjqyEWLlo= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1031,8 +1032,8 @@ golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.8 h1:P1HhGGuLW4aAclzjtmJdf0mJOjVUZUzOTqkAkWL+l6w= -golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/tools v0.1.9 h1:j9KsMiaP1c3B0OTQGth0/k+miLGTgLsAFUCrF2vLcF8= +golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1072,14 +1073,14 @@ google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6 google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI= -google.golang.org/api v0.58.0/go.mod h1:cAbP2FsxoGVNwtgNAmmn3y5G1TWAiVYRmg4yku3lv+E= google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUbuZU= google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw= google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo= google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM= -google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ= google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg= +google.golang.org/api v0.66.0 h1:CbGy4LEiXCVCiNEDFgGpWOVwsDT7E2Qej1ZvN1P7KPg= +google.golang.org/api v0.66.0/go.mod h1:I1dmXYpX7HGwz/ejRxwQp2qj5bFAz93HiCU1C1oYd9M= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1152,10 +1153,8 @@ google.golang.org/genproto v0.0.0-20210828152312-66f60bf46e71/go.mod h1:eFjDcFEc google.golang.org/genproto v0.0.0-20210831024726-fe130286e0e2/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= -google.golang.org/genproto v0.0.0-20210917145530-b395a37504d4/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211008145708-270636b82663/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20211016002631-37fc39342514/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211028162531-8db9c33dc351/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211129164237-f09f9a12af12/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= @@ -1166,8 +1165,10 @@ google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ6 google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q= +google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350 h1:YxHp5zqIcAShDEvRr5/0rVESVS+njYF68PSdazrNLJo= +google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= @@ -1196,8 +1197,9 @@ google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= -google.golang.org/grpc v1.43.0 h1:Eeu7bZtDZ2DpRCsLhUlcrLnvYaMK1Gz86a+hMVvELmM= google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg= +google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/httpd/api_http_user.go b/httpd/api_http_user.go index 2b55c3a8..e3b6ce3e 100644 --- a/httpd/api_http_user.go +++ b/httpd/api_http_user.go @@ -270,6 +270,13 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) { if err != nil { return } + transferQuota := connection.GetTransferQuota() + if !transferQuota.HasUploadSpace() { + connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") + sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits", + http.StatusRequestEntityTooLarge) + return + } common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) diff --git a/httpd/api_quota.go b/httpd/api_quota.go index a01978eb..bfc85826 100644 --- a/httpd/api_quota.go +++ b/httpd/api_quota.go @@ -23,6 +23,11 @@ type quotaUsage struct { UsedQuotaFiles int `json:"used_quota_files"` } +type transferQuotaUsage struct { + UsedUploadDataTransfer int64 `json:"used_upload_data_transfer"` + UsedDownloadDataTransfer int64 `json:"used_download_data_transfer"` +} + func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans()) @@ -118,6 +123,43 @@ func startFolderQuotaScanCompat(w http.ResponseWriter, r *http.Request) { doStartFolderQuotaScan(w, r, f.Name) } +func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var usage transferQuotaUsage + err := render.DecodeJSON(r.Body, &usage) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if usage.UsedUploadDataTransfer < 0 || usage.UsedDownloadDataTransfer < 0 { + sendAPIResponse(w, r, errors.New("invalid used transfer quota parameters, negative values are not allowed"), + "", http.StatusBadRequest) + return + } + mode, err := getQuotaUpdateMode(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.UserExists(getURLParam(r, "username")) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if mode == quotaUpdateModeAdd && !user.HasTransferQuotaRestrictions() && dataprovider.GetQuotaTracking() == 2 { + sendAPIResponse(w, r, errors.New("this user has no transfer quota restrictions, only reset mode is supported"), + "", http.StatusBadRequest) + return + } + err = dataprovider.UpdateUserTransferQuota(&user, usage.UsedUploadDataTransfer, usage.UsedDownloadDataTransfer, + mode == quotaUpdateModeReset) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) +} + func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username string, usage quotaUsage) { if usage.UsedQuotaFiles < 0 || usage.UsedQuotaSize < 0 { sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), @@ -147,9 +189,9 @@ func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username str err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) - } else { - sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) + return } + sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) } func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name string, usage quotaUsage) { diff --git a/httpd/api_shares.go b/httpd/api_shares.go index 6f4fbb72..858811e6 100644 --- a/httpd/api_shares.go +++ b/httpd/api_shares.go @@ -13,6 +13,7 @@ import ( "github.com/drakkan/sftpgo/v2/common" "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/logger" "github.com/drakkan/sftpgo/v2/util" ) @@ -159,6 +160,12 @@ func downloadFromShare(w http.ResponseWriter, r *http.Request) { dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck if compress { + transferQuota := connection.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + err = connection.GetReadQuotaExceededError() + connection.Log(logger.LevelInfo, "denying share read due to quota limits") + sendAPIResponse(w, r, err, "", getMappedStatusCode(err)) + } w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"share-%v.zip\"", share.ShareID)) renderCompressedFiles(w, connection, "/", share.Paths, &share) return @@ -209,6 +216,14 @@ func uploadFilesToShare(w http.ResponseWriter, r *http.Request) { return } + transferQuota := connection.GetTransferQuota() + if !transferQuota.HasUploadSpace() { + connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") + sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits", + http.StatusRequestEntityTooLarge) + return + } + common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) diff --git a/httpd/api_utils.go b/httpd/api_utils.go index cfab52a2..cdc75418 100644 --- a/httpd/api_utils.go +++ b/httpd/api_utils.go @@ -97,6 +97,8 @@ func getMappedStatusCode(err error) int { switch { case errors.Is(err, os.ErrPermission): statusCode = http.StatusForbidden + case errors.Is(err, common.ErrReadQuotaExceeded): + statusCode = http.StatusForbidden case errors.Is(err, os.ErrNotExist): statusCode = http.StatusNotFound case errors.Is(err, common.ErrQuotaExceeded): @@ -310,7 +312,11 @@ func downloadFile(w http.ResponseWriter, r *http.Request, connection *Connection w.Header().Set("Accept-Ranges", "bytes") w.WriteHeader(responseStatus) if r.Method != http.MethodHead { - io.CopyN(w, reader, size) //nolint:errcheck + _, err = io.CopyN(w, reader, size) + if err != nil { + connection.Log(logger.LevelDebug, "error reading file to download: %v", err) + panic(http.ErrAbortHandler) + } } return http.StatusOK, nil } diff --git a/httpd/file.go b/httpd/file.go index c43381c5..8178016b 100644 --- a/httpd/file.go +++ b/httpd/file.go @@ -49,6 +49,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) { n, err = f.reader.Read(p) atomic.AddInt64(&f.BytesSent, int64(n)) + if err == nil { + err = f.CheckRead() + } if err != nil && err != io.EOF { f.TransferError(err) return @@ -70,8 +73,8 @@ func (f *httpdFile) Write(p []byte) (n int, err error) { n, err = f.writer.Write(p) atomic.AddInt64(&f.BytesReceived, int64(n)) - if f.MaxWriteSize > 0 && err == nil && atomic.LoadInt64(&f.BytesReceived) > f.MaxWriteSize { - err = common.ErrQuotaExceeded + if err == nil { + err = f.CheckWrite() } if err != nil { f.TransferError(err) diff --git a/httpd/handler.go b/httpd/handler.go index 9d1bcd82..fa56e233 100644 --- a/httpd/handler.go +++ b/httpd/handler.go @@ -85,6 +85,12 @@ func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { func (c *Connection) getFileReader(name string, offset int64, method string) (io.ReadCloser, error) { c.UpdateLastActivity() + transferQuota := c.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.Log(logger.LevelInfo, "denying file read due to quota limits") + return nil, c.GetReadQuotaExceededError() + } + name = util.CleanPath(name) if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(name)) { return nil, c.GetPermissionDeniedError() @@ -114,7 +120,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, transferQuota) return newHTTPDFile(baseTransfer, nil, r), nil } @@ -171,8 +177,8 @@ func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) { } func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, isNewFile bool, fileSize int64) (io.WriteCloser, error) { - quotaResult := c.HasSpace(isNewFile, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(isNewFile, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } @@ -182,7 +188,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request return nil, c.GetPermissionDeniedError() } - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { @@ -215,7 +221,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota) return newHTTPDFile(baseTransfer, w, nil), nil } @@ -294,8 +300,8 @@ func (t *throttledReader) GetTruncatedSize() int64 { return 0 } -func (t *throttledReader) GetMaxAllowedSize() int64 { - return 0 +func (t *throttledReader) HasSizeLimit() bool { + return false } func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) { diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 2b98aa60..e5e71e3c 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -297,7 +297,7 @@ func TestMain(m *testing.M) { os.RemoveAll(credentialsPath) //nolint:errcheck logger.InfoToConsole("Starting HTTPD tests, provider: %v", providerConf.Driver) - err = common.Initialize(config.GetCommonConfig()) + err = common.Initialize(config.GetCommonConfig(), 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) @@ -506,7 +506,104 @@ func TestBasicUserHandling(t *testing.T) { assert.NoError(t, err) } -func TestUserBandwidthLimit(t *testing.T) { +func TestUserTransferLimits(t *testing.T) { + u := getTestUser() + u.TotalDataTransfer = 100 + u.Filters.DataTransferLimits = []sdk.DataTransferLimit{ + { + Sources: nil, + }, + } + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "Validation error: no data transfer limit source specified") + u.Filters.DataTransferLimits = []sdk.DataTransferLimit{ + { + Sources: []string{"a"}, + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "Validation error: could not parse data transfer limit source") + u.Filters.DataTransferLimits = []sdk.DataTransferLimit{ + { + Sources: []string{"127.0.0.1/32"}, + UploadDataTransfer: 120, + DownloadDataTransfer: 140, + }, + { + Sources: []string{"192.168.0.0/24", "192.168.1.0/24"}, + TotalDataTransfer: 400, + }, + { + Sources: []string{"10.0.0.0/8"}, + }, + } + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, user.Filters.DataTransferLimits, 3) + assert.Equal(t, u.Filters.DataTransferLimits, user.Filters.DataTransferLimits) + up, down, total := user.GetDataTransferLimits("1.1.1.1") + assert.Equal(t, user.TotalDataTransfer*1024*1024, total) + assert.Equal(t, user.UploadDataTransfer*1024*1024, up) + assert.Equal(t, user.DownloadDataTransfer*1024*1024, down) + up, down, total = user.GetDataTransferLimits("127.0.0.1") + assert.Equal(t, user.Filters.DataTransferLimits[0].TotalDataTransfer*1024*1024, total) + assert.Equal(t, user.Filters.DataTransferLimits[0].UploadDataTransfer*1024*1024, up) + assert.Equal(t, user.Filters.DataTransferLimits[0].DownloadDataTransfer*1024*1024, down) + up, down, total = user.GetDataTransferLimits("192.168.1.6") + assert.Equal(t, user.Filters.DataTransferLimits[1].TotalDataTransfer*1024*1024, total) + assert.Equal(t, user.Filters.DataTransferLimits[1].UploadDataTransfer*1024*1024, up) + assert.Equal(t, user.Filters.DataTransferLimits[1].DownloadDataTransfer*1024*1024, down) + up, down, total = user.GetDataTransferLimits("10.1.2.3") + assert.Equal(t, user.Filters.DataTransferLimits[2].TotalDataTransfer*1024*1024, total) + assert.Equal(t, user.Filters.DataTransferLimits[2].UploadDataTransfer*1024*1024, up) + assert.Equal(t, user.Filters.DataTransferLimits[2].DownloadDataTransfer*1024*1024, down) + + connID := xid.New().String() + localAddr := "::1" + conn := common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "1.1.1.2", user) + transferQuota := conn.GetTransferQuota() + assert.Equal(t, user.TotalDataTransfer*1024*1024, transferQuota.AllowedTotalSize) + assert.Equal(t, user.UploadDataTransfer*1024*1024, transferQuota.AllowedULSize) + assert.Equal(t, user.DownloadDataTransfer*1024*1024, transferQuota.AllowedDLSize) + + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "127.0.0.1", user) + transferQuota = conn.GetTransferQuota() + assert.Equal(t, user.Filters.DataTransferLimits[0].TotalDataTransfer*1024*1024, transferQuota.AllowedTotalSize) + assert.Equal(t, user.Filters.DataTransferLimits[0].UploadDataTransfer*1024*1024, transferQuota.AllowedULSize) + assert.Equal(t, user.Filters.DataTransferLimits[0].DownloadDataTransfer*1024*1024, transferQuota.AllowedDLSize) + + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "192.168.1.5", user) + transferQuota = conn.GetTransferQuota() + assert.Equal(t, user.Filters.DataTransferLimits[1].TotalDataTransfer*1024*1024, transferQuota.AllowedTotalSize) + assert.Equal(t, user.Filters.DataTransferLimits[1].UploadDataTransfer*1024*1024, transferQuota.AllowedULSize) + assert.Equal(t, user.Filters.DataTransferLimits[1].DownloadDataTransfer*1024*1024, transferQuota.AllowedDLSize) + + u.UsedDownloadDataTransfer = 10 * 1024 * 1024 + u.UsedUploadDataTransfer = 5 * 1024 * 1024 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "192.168.1.6", user) + transferQuota = conn.GetTransferQuota() + assert.Equal(t, (user.Filters.DataTransferLimits[1].TotalDataTransfer-15)*1024*1024, transferQuota.AllowedTotalSize) + assert.Equal(t, user.Filters.DataTransferLimits[1].UploadDataTransfer*1024*1024, transferQuota.AllowedULSize) + assert.Equal(t, user.Filters.DataTransferLimits[1].DownloadDataTransfer*1024*1024, transferQuota.AllowedDLSize) + + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "10.8.3.4", user) + transferQuota = conn.GetTransferQuota() + assert.Equal(t, int64(0), transferQuota.AllowedTotalSize) + assert.Equal(t, int64(0), transferQuota.AllowedULSize) + assert.Equal(t, int64(0), transferQuota.AllowedDLSize) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserBandwidthLimits(t *testing.T) { u := getTestUser() u.UploadBandwidth = 128 u.DownloadBandwidth = 96 @@ -518,6 +615,14 @@ func TestUserBandwidthLimit(t *testing.T) { _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err, string(resp)) assert.Contains(t, string(resp), "Validation error: could not parse bandwidth limit source") + u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ + { + Sources: nil, + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "Validation error: no bandwidth limit source specified") u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ { Sources: []string{"127.0.0.0/8", "::1/128"}, @@ -2163,6 +2268,81 @@ func TestUpdateUser(t *testing.T) { assert.NoError(t, err) } +func TestUpdateUserTransferQuotaUsage(t *testing.T) { + u := getTestUser() + usedDownloadDataTransfer := int64(2 * 1024 * 1024) + usedUploadDataTransfer := int64(1024 * 1024) + u.UsedDownloadDataTransfer = usedDownloadDataTransfer + u.UsedUploadDataTransfer = usedUploadDataTransfer + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "invalid_mode", http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) + assert.NoError(t, err, "user has no transfer quota restrictions add mode should fail") + user.TotalDataTransfer = 100 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2*usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, 2*usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = -1 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) + assert.NoError(t, err) + u.UsedDownloadDataTransfer = usedDownloadDataTransfer + u.Username += "1" + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusNotFound) + assert.NoError(t, err) + u.Username = defaultUsername + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = 0 + u.UsedUploadDataTransfer = 1 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = 1 + u.UsedUploadDataTransfer = 0 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer+1, user.UsedDownloadDataTransfer) + + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "transfer-usage"), + bytes.NewBuffer([]byte(`not a json`))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + func TestUpdateUserQuotaUsage(t *testing.T) { u := getTestUser() usedQuotaFiles := 1 @@ -2171,6 +2351,10 @@ func TestUpdateUserQuotaUsage(t *testing.T) { u.UsedQuotaSize = usedQuotaSize user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) _, err = httpdtest.UpdateQuotaUsage(u, "invalid_mode", http.StatusBadRequest) assert.NoError(t, err) _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusOK) @@ -3948,6 +4132,8 @@ func TestQuotaTrackingDisabled(t *testing.T) { assert.NoError(t, err) _, err = httpdtest.UpdateQuotaUsage(user, "", http.StatusForbidden) assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(user, "", http.StatusForbidden) + assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) // folder quota scan must fail @@ -4306,7 +4492,7 @@ func TestDefenderAPI(t *testing.T) { cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) ip := "::1" @@ -4405,7 +4591,7 @@ func TestDefenderAPI(t *testing.T) { } } - err := common.Initialize(oldConfig) + err := common.Initialize(oldConfig, 0) require.NoError(t, err) } @@ -4427,7 +4613,7 @@ func TestDefenderAPIErrors(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Driver = common.DefenderDriverProvider - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) require.NoError(t, err) token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) @@ -4465,7 +4651,7 @@ func TestDefenderAPIErrors(t *testing.T) { err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) require.NoError(t, err) } } @@ -4851,6 +5037,8 @@ func TestLoaddataMode(t *testing.T) { assert.NoError(t, err) _, err = httpdtest.RemoveFolder(folder, http.StatusOK) assert.NoError(t, err) + _, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK) + assert.NoError(t, err) err = os.Remove(backupFilePath) assert.NoError(t, err) } @@ -4869,7 +5057,7 @@ func TestRateLimiter(t *testing.T) { }, } - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) client := &http.Client{ @@ -4905,7 +5093,7 @@ func TestRateLimiter(t *testing.T) { err = resp.Body.Close() assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -8333,7 +8521,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) @@ -8389,7 +8577,7 @@ func TestDefender(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -10659,6 +10847,132 @@ func TestWebFilesAPI(t *testing.T) { checkResponseCode(t, http.StatusNotFound, rr) } +func TestWebFilesTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.UploadDataTransfer = 1 + u.DownloadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + testFileName := "file.data" + testFileSize := 550000 + testFileContents := make([]byte, testFileSize) + n, err := io.ReadFull(rand.Reader, testFileContents) + assert.NoError(t, err) + assert.Equal(t, testFileSize, n) + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", testFileName) + assert.NoError(t, err) + _, err = part.Write(testFileContents) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + // error while download is active + downloadFunc := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + } + downloadFunc() + // error before starting the download + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // error while upload is active + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + // error before starting the upload + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + // now test upload/download to/from shares + share1 := dataprovider.Share{ + Name: "share1", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + } + asJSON, err := json.Marshal(share1) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + share2 := dataprovider.Share{ + Name: "share2", + Scope: dataprovider.ShareScopeWrite, + Paths: []string{"/"}, + } + asJSON, err = json.Marshal(share2) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestWebUploadErrors(t *testing.T) { u := getTestUser() u.QuotaSize = 65535 @@ -11318,6 +11632,10 @@ func TestClientUserClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() req, _ := http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) setJWTCookieForReq(req, webToken) rr := executeRequest(req) @@ -13035,6 +13353,8 @@ func TestWebUserAddMock(t *testing.T) { user := getTestUser() user.UploadBandwidth = 32 user.DownloadBandwidth = 64 + user.UploadDataTransfer = 1000 + user.DownloadDataTransfer = 2000 user.UID = 1000 user.AdditionalInfo = "info" user.Description = "user dsc" @@ -13088,6 +13408,7 @@ func TestWebUserAddMock(t *testing.T) { form.Set("description", user.Description) form.Add("hooks", "external_auth_disabled") form.Set("disable_fs_checks", "checked") + form.Set("total_data_transfer", "0") b, contentType, _ := getMultipartFormData(form, "", "") // test invalid url escape req, _ = http.NewRequest(http.MethodPost, webUserPath+"?a=%2", &b) @@ -13152,6 +13473,33 @@ func TestWebUserAddMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) form.Set("download_bandwidth", strconv.FormatInt(user.DownloadBandwidth, 10)) + form.Set("upload_data_transfer", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + // test invalid upload data transfer + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form.Set("upload_data_transfer", strconv.FormatInt(user.UploadDataTransfer, 10)) + form.Set("download_data_transfer", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + // test invalid download data transfer + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form.Set("download_data_transfer", strconv.FormatInt(user.DownloadDataTransfer, 10)) + form.Set("total_data_transfer", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + // test invalid total data transfer + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form.Set("total_data_transfer", strconv.FormatInt(user.TotalDataTransfer, 10)) form.Set("status", "a") b, contentType, _ = getMultipartFormData(form, "", "") // test invalid status @@ -13240,6 +13588,49 @@ func TestWebUserAddMock(t *testing.T) { assert.Contains(t, rr.Body.String(), "Validation error: could not parse bandwidth limit source") form.Set("bandwidth_limit_sources1", "127.0.0.1/32") form.Set("upload_bandwidth_source1", "-1") + form.Set("data_transfer_limit_sources0", "127.0.1.1") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "could not parse data transfer limit source") + form.Set("data_transfer_limit_sources0", "127.0.1.1/32") + form.Set("upload_data_transfer_source0", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "invalid upload_data_transfer_source") + form.Set("upload_data_transfer_source0", "0") + form.Set("download_data_transfer_source0", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "invalid download_data_transfer_source") + form.Set("download_data_transfer_source0", "0") + form.Set("total_data_transfer_source0", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "invalid total_data_transfer_source") + form.Set("total_data_transfer_source0", "0") + form.Set("data_transfer_limit_sources10", "192.168.5.0/24, 10.8.0.0/16") + form.Set("download_data_transfer_source10", "100") + form.Set("upload_data_transfer_source10", "120") + form.Set("data_transfer_limit_sources12", "192.168.3.0/24, 10.8.2.0/24,::1/64") + form.Set("download_data_transfer_source12", "100") + form.Set("upload_data_transfer_source12", "120") + form.Set("total_data_transfer_source12", "200") form.Set(csrfFormToken, "invalid form token") b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, webUserPath, &b) @@ -13278,6 +13669,9 @@ func TestWebUserAddMock(t *testing.T) { assert.Equal(t, user.UID, newUser.UID) assert.Equal(t, user.UploadBandwidth, newUser.UploadBandwidth) assert.Equal(t, user.DownloadBandwidth, newUser.DownloadBandwidth) + assert.Equal(t, user.UploadDataTransfer, newUser.UploadDataTransfer) + assert.Equal(t, user.DownloadDataTransfer, newUser.DownloadDataTransfer) + assert.Equal(t, user.TotalDataTransfer, newUser.TotalDataTransfer) assert.Equal(t, int64(1000), newUser.Filters.MaxUploadFileSize) assert.Equal(t, user.AdditionalInfo, newUser.AdditionalInfo) assert.Equal(t, user.Description, newUser.Description) @@ -13340,6 +13734,30 @@ func TestWebUserAddMock(t *testing.T) { } } } + if assert.Len(t, newUser.Filters.DataTransferLimits, 3) { + for _, dtLimit := range newUser.Filters.DataTransferLimits { + switch len(dtLimit.Sources) { + case 3: + assert.Equal(t, "192.168.3.0/24", dtLimit.Sources[0]) + assert.Equal(t, "10.8.2.0/24", dtLimit.Sources[1]) + assert.Equal(t, "::1/64", dtLimit.Sources[2]) + assert.Equal(t, int64(0), dtLimit.UploadDataTransfer) + assert.Equal(t, int64(0), dtLimit.DownloadDataTransfer) + assert.Equal(t, int64(200), dtLimit.TotalDataTransfer) + case 2: + assert.Equal(t, "192.168.5.0/24", dtLimit.Sources[0]) + assert.Equal(t, "10.8.0.0/16", dtLimit.Sources[1]) + assert.Equal(t, int64(120), dtLimit.UploadDataTransfer) + assert.Equal(t, int64(100), dtLimit.DownloadDataTransfer) + assert.Equal(t, int64(0), dtLimit.TotalDataTransfer) + case 1: + assert.Equal(t, "127.0.1.1/32", dtLimit.Sources[0]) + assert.Equal(t, int64(0), dtLimit.UploadDataTransfer) + assert.Equal(t, int64(0), dtLimit.DownloadDataTransfer) + assert.Equal(t, int64(0), dtLimit.TotalDataTransfer) + } + } + } assert.Equal(t, sdk.TLSUsernameNone, newUser.Filters.TLSUsername) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, newUser.Username), nil) @@ -13367,6 +13785,7 @@ func TestWebUserUpdateMock(t *testing.T) { DownloadBandwidth: 512, }, } + user.TotalDataTransfer = 4000 userAsJSON := getUserAsJSON(t, user) req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) setBearerForReq(req, apiToken) @@ -13402,6 +13821,7 @@ func TestWebUserUpdateMock(t *testing.T) { user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Equal(t, int64(4000), user.TotalDataTransfer) if assert.Len(t, user.Filters.BandwidthLimits, 1) { if assert.Len(t, user.Filters.BandwidthLimits[0].Sources, 2) { assert.Equal(t, "10.8.0.0/16", user.Filters.BandwidthLimits[0].Sources[0]) @@ -13437,6 +13857,9 @@ func TestWebUserUpdateMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("sub_perm_path0", "/otherdir") form.Set("sub_perm_permissions0", "list") @@ -13523,7 +13946,9 @@ func TestWebUserUpdateMock(t *testing.T) { assert.Equal(t, sdk.TLSUsernameCN, updateUser.Filters.TLSUsername) assert.True(t, updateUser.Filters.AllowAPIKeyAuth) assert.True(t, updateUser.Filters.TOTPConfig.Enabled) - + assert.Equal(t, int64(0), updateUser.TotalDataTransfer) + assert.Equal(t, int64(0), updateUser.DownloadDataTransfer) + assert.Equal(t, int64(0), updateUser.UploadDataTransfer) if val, ok := updateUser.Permissions["/otherdir"]; ok { assert.True(t, util.IsStringInSlice(dataprovider.PermListItems, val)) assert.True(t, util.IsStringInSlice(dataprovider.PermUpload, val)) @@ -13626,6 +14051,9 @@ func TestUserTemplateWithFoldersMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -13713,6 +14141,9 @@ func TestUserSaveFromTemplateMock(t *testing.T) { form.Set("home_dir", filepath.Join(os.TempDir(), "%username%")) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("uid", "0") form.Set("gid", "0") form.Set("max_sessions", "0") @@ -13794,6 +14225,9 @@ func TestUserTemplateMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -14136,6 +14570,9 @@ func TestWebUserS3Mock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -14336,6 +14773,9 @@ func TestWebUserGCSMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -14448,6 +14888,9 @@ func TestWebUserAzureBlobMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -14608,6 +15051,9 @@ func TestWebUserCryptMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") @@ -14710,6 +15156,9 @@ func TestWebUserSFTPFsMock(t *testing.T) { form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) form.Set("upload_bandwidth", "0") form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") form.Set("permissions", "*") form.Set("status", strconv.Itoa(user.Status)) form.Set("expiration_date", "2020-01-01 00:00:00") diff --git a/httpd/internal_test.go b/httpd/internal_test.go index 7d8de8bd..263535d1 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -1882,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) { assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) httpdFile := newHTTPDFile(baseTransfer, nil, nil) // the file is closed, read should fail buf := make([]byte, 100) diff --git a/httpd/server.go b/httpd/server.go index b0a2ac74..e181bc2a 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -1075,9 +1075,13 @@ func (s *httpdServer) initializeRouter() { router.With(checkPerm(dataprovider.PermAdminManageSystem)).Get(loadDataPath, loadData) router.With(checkPerm(dataprovider.PermAdminManageSystem)).Post(loadDataPath, loadDataFromRequest) router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateUsedQuotaPath, updateUserQuotaUsageCompat) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/usage", updateUserQuotaUsage) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/usage", + updateUserQuotaUsage) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/transfer-usage", + updateUserTransferQuotaUsage) router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(updateFolderUsedQuotaPath, updateFolderQuotaUsageCompat) - router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/folders/{name}/usage", updateFolderQuotaUsage) + router.With(checkPerm(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/folders/{name}/usage", + updateFolderQuotaUsage) router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderHosts, getDefenderHosts) router.With(checkPerm(dataprovider.PermAdminViewDefender)).Get(defenderHosts+"/{id}", getDefenderHostByID) router.With(checkPerm(dataprovider.PermAdminManageDefender)).Delete(defenderHosts+"/{id}", deleteDefenderHostByID) diff --git a/httpd/webadmin.go b/httpd/webadmin.go index 927fa226..2e669170 100644 --- a/httpd/webadmin.go +++ b/httpd/webadmin.go @@ -762,6 +762,50 @@ func getUserPermissionsFromPostFields(r *http.Request) map[string][]string { return permissions } +func getDataTransferLimitsFromPostFields(r *http.Request) ([]sdk.DataTransferLimit, error) { + var result []sdk.DataTransferLimit + + for k := range r.Form { + if strings.HasPrefix(k, "data_transfer_limit_sources") { + sources := getSliceFromDelimitedValues(r.Form.Get(k), ",") + if len(sources) > 0 { + dtLimit := sdk.DataTransferLimit{ + Sources: sources, + } + idx := strings.TrimPrefix(k, "data_transfer_limit_sources") + ul := r.Form.Get(fmt.Sprintf("upload_data_transfer_source%v", idx)) + dl := r.Form.Get(fmt.Sprintf("download_data_transfer_source%v", idx)) + total := r.Form.Get(fmt.Sprintf("total_data_transfer_source%v", idx)) + if ul != "" { + dataUL, err := strconv.ParseInt(ul, 10, 64) + if err != nil { + return result, fmt.Errorf("invalid upload_data_transfer_source%v %#v: %w", idx, ul, err) + } + dtLimit.UploadDataTransfer = dataUL + } + if dl != "" { + dataDL, err := strconv.ParseInt(dl, 10, 64) + if err != nil { + return result, fmt.Errorf("invalid download_data_transfer_source%v %#v: %w", idx, dl, err) + } + dtLimit.DownloadDataTransfer = dataDL + } + if total != "" { + dataTotal, err := strconv.ParseInt(total, 10, 64) + if err != nil { + return result, fmt.Errorf("invalid total_data_transfer_source%v %#v: %w", idx, total, err) + } + dtLimit.TotalDataTransfer = dataTotal + } + + result = append(result, dtLimit) + } + } + } + + return result, nil +} + func getBandwidthLimitsFromPostFields(r *http.Request) ([]sdk.BandwidthLimit, error) { var result []sdk.BandwidthLimit @@ -872,7 +916,12 @@ func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error) if err != nil { return filters, err } + dtLimits, err := getDataTransferLimitsFromPostFields(r) + if err != nil { + return filters, err + } filters.BandwidthLimits = bwLimits + filters.DataTransferLimits = dtLimits filters.AllowedIP = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",") filters.DeniedIP = getSliceFromDelimitedValues(r.Form.Get("denied_ip"), ",") filters.DeniedLoginMethods = r.Form["ssh_login_methods"] @@ -1176,6 +1225,22 @@ func getUserFromTemplate(user dataprovider.User, template userTemplateFields) da return user } +func getTransferLimits(r *http.Request) (int64, int64, int64, error) { + dataTransferUL, err := strconv.ParseInt(r.Form.Get("upload_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, err + } + dataTransferDL, err := strconv.ParseInt(r.Form.Get("download_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, err + } + dataTransferTotal, err := strconv.ParseInt(r.Form.Get("total_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, err + } + return dataTransferUL, dataTransferDL, dataTransferTotal, nil +} + func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { var user dataprovider.User err := r.ParseMultipartForm(maxRequestSize) @@ -1211,6 +1276,10 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { if err != nil { return user, err } + dataTransferUL, dataTransferDL, dataTransferTotal, err := getTransferLimits(r) + if err != nil { + return user, err + } status, err := strconv.Atoi(r.Form.Get("status")) if err != nil { return user, err @@ -1234,23 +1303,26 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { } user = dataprovider.User{ BaseUser: sdk.BaseUser{ - Username: r.Form.Get("username"), - Email: r.Form.Get("email"), - Password: r.Form.Get("password"), - PublicKeys: r.Form["public_keys"], - HomeDir: r.Form.Get("home_dir"), - UID: uid, - GID: gid, - Permissions: getUserPermissionsFromPostFields(r), - MaxSessions: maxSessions, - QuotaSize: quotaSize, - QuotaFiles: quotaFiles, - UploadBandwidth: bandwidthUL, - DownloadBandwidth: bandwidthDL, - Status: status, - ExpirationDate: expirationDateMillis, - AdditionalInfo: r.Form.Get("additional_info"), - Description: r.Form.Get("description"), + Username: r.Form.Get("username"), + Email: r.Form.Get("email"), + Password: r.Form.Get("password"), + PublicKeys: r.Form["public_keys"], + HomeDir: r.Form.Get("home_dir"), + UID: uid, + GID: gid, + Permissions: getUserPermissionsFromPostFields(r), + MaxSessions: maxSessions, + QuotaSize: quotaSize, + QuotaFiles: quotaFiles, + UploadBandwidth: bandwidthUL, + DownloadBandwidth: bandwidthDL, + UploadDataTransfer: dataTransferUL, + DownloadDataTransfer: dataTransferDL, + TotalDataTransfer: dataTransferTotal, + Status: status, + ExpirationDate: expirationDateMillis, + AdditionalInfo: r.Form.Get("additional_info"), + Description: r.Form.Get("description"), }, Filters: dataprovider.UserFilters{ BaseUserFilters: filters, diff --git a/httpdtest/httpdtest.go b/httpdtest/httpdtest.go index ef8b1891..8b052010 100644 --- a/httpdtest/httpdtest.go +++ b/httpdtest/httpdtest.go @@ -521,7 +521,8 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, err return body, checkResponse(resp.StatusCode, expectedStatusCode) } -// UpdateQuotaUsage updates the user used quota limits and checks the received HTTP Status code against expectedStatusCode. +// UpdateQuotaUsage updates the user used quota limits and checks the received +// HTTP Status code against expectedStatusCode. func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { var body []byte userAsJSON, _ := json.Marshal(user) @@ -539,6 +540,25 @@ func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode in return body, checkResponse(resp.StatusCode, expectedStatusCode) } +// UpdateTransferQuotaUsage updates the user used transfer quota limits and checks the received +// HTTP Status code against expectedStatusCode. +func UpdateTransferQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { + var body []byte + userAsJSON, _ := json.Marshal(user) + url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "transfer-usage"), mode) + if err != nil { + return body, err + } + resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", + getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + // GetRetentionChecks returns the active retention checks func GetRetentionChecks(expectedStatusCode int) ([]common.ActiveRetentionChecks, []byte, error) { var checks []common.ActiveRetentionChecks @@ -1495,6 +1515,9 @@ func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) if err := compareUserBandwidthLimitFilters(expected, actual); err != nil { return err } + if err := compareUserDataTransferLimitFilters(expected, actual); err != nil { + return err + } return compareUserFilePatternsFilters(expected, actual) } @@ -1510,9 +1533,33 @@ func checkFilterMatch(expected []string, actual []string) bool { return true } +func compareUserDataTransferLimitFilters(expected *dataprovider.User, actual *dataprovider.User) error { + if len(expected.Filters.DataTransferLimits) != len(actual.Filters.DataTransferLimits) { + return errors.New("data transfer limits filters mismatch") + } + for idx, l := range expected.Filters.DataTransferLimits { + if actual.Filters.DataTransferLimits[idx].UploadDataTransfer != l.UploadDataTransfer { + return errors.New("data transfer limit upload_data_transfer mismatch") + } + if actual.Filters.DataTransferLimits[idx].DownloadDataTransfer != l.DownloadDataTransfer { + return errors.New("data transfer limit download_data_transfer mismatch") + } + if actual.Filters.DataTransferLimits[idx].TotalDataTransfer != l.TotalDataTransfer { + return errors.New("data transfer limit total_data_transfer mismatch") + } + for _, source := range actual.Filters.DataTransferLimits[idx].Sources { + if !util.IsStringInSlice(source, l.Sources) { + return errors.New("data transfer limit source mismatch") + } + } + } + + return nil +} + func compareUserBandwidthLimitFilters(expected *dataprovider.User, actual *dataprovider.User) error { if len(expected.Filters.BandwidthLimits) != len(actual.Filters.BandwidthLimits) { - return errors.New("bandwidth filters mismatch") + return errors.New("bandwidth limits filters mismatch") } for idx, l := range expected.Filters.BandwidthLimits { @@ -1573,12 +1620,6 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U if expected.MaxSessions != actual.MaxSessions { return errors.New("MaxSessions mismatch") } - if expected.QuotaSize != actual.QuotaSize { - return errors.New("QuotaSize mismatch") - } - if expected.QuotaFiles != actual.QuotaFiles { - return errors.New("QuotaFiles mismatch") - } if len(expected.Permissions) != len(actual.Permissions) { return errors.New("permissions mismatch") } @@ -1600,6 +1641,25 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U if expected.Description != actual.Description { return errors.New("description mismatch") } + return compareQuotaUserFields(expected, actual) +} + +func compareQuotaUserFields(expected *dataprovider.User, actual *dataprovider.User) error { + if expected.QuotaSize != actual.QuotaSize { + return errors.New("QuotaSize mismatch") + } + if expected.QuotaFiles != actual.QuotaFiles { + return errors.New("QuotaFiles mismatch") + } + if expected.UploadDataTransfer != actual.UploadDataTransfer { + return errors.New("upload_data_transfer mismatch") + } + if expected.DownloadDataTransfer != actual.DownloadDataTransfer { + return errors.New("download_data_transfer mismatch") + } + if expected.TotalDataTransfer != actual.TotalDataTransfer { + return errors.New("total_data_transfer mismatch") + } return nil } diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 9930cef6..86b4fe96 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1124,24 +1124,9 @@ paths: put: tags: - quota - summary: Update quota usage limits + summary: Update disk quota usage limits description: Sets the current used quota limits for the given user operationId: user_quota_update_usage - parameters: - - in: query - name: mode - required: false - description: the update mode specifies if the given quota usage values should be added or replace the current ones - schema: - type: string - enum: - - add - - reset - description: | - Update type: - * `add` - add the specified quota limits to the current used ones - * `reset` - reset the values to the specified ones. This is the default - example: reset requestBody: required: true description: 'If used_quota_size and used_quota_files are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' @@ -1172,6 +1157,64 @@ paths: $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' + /quotas/users/{username}/transfer-usage: + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + - in: query + name: mode + required: false + description: the update mode specifies if the given quota usage values should be added or replace the current ones + schema: + type: string + enum: + - add + - reset + description: | + Update type: + * `add` - add the specified quota limits to the current used ones + * `reset` - reset the values to the specified ones. This is the default + example: reset + put: + tags: + - quota + summary: Update transfer quota usage limits + description: Sets the current used transfer quota limits for the given user + operationId: user_transfer_quota_update_usage + requestBody: + required: true + description: 'If used_upload_data_transfer and used_download_data_transfer are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' + content: + application/json: + schema: + $ref: '#/components/schemas/TransferQuotaUsage' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Quota updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' /quotas/folders/scans: get: tags: @@ -4425,6 +4468,23 @@ components: type: integer format: int32 description: 'Maximum download bandwidth as KB/s, 0 means unlimited' + DataTransferLimit: + type: object + properties: + sources: + type: array + items: + type: string + description: 'Source networks in CIDR notation as defined in RFC 4632 and RFC 4291 for example `192.0.2.0/24` or `2001:db8::/32`. The limit applies if the defined networks contain the client IP' + upload_data_transfer: + type: integer + description: 'Maximum data transfer allowed for uploads as MB. 0 means no limit' + download_data_transfer: + type: integer + description: 'Maximum data transfer allowed for downloads as MB. 0 means no limit' + total_data_transfer: + type: integer + description: 'Maximum total data transfer as MB. 0 means unlimited. You can set a total data transfer instead of the individual values for uploads and downloads' UserFilters: type: object properties: @@ -4494,6 +4554,10 @@ components: type: array items: $ref: '#/components/schemas/BandwidthLimit' + data_transfer_limits: + type: array + items: + $ref: '#/components/schemas/DataTransferLimit' description: Additional user options Secret: type: object @@ -4824,12 +4888,25 @@ components: description: Last quota update as unix timestamp in milliseconds upload_bandwidth: type: integer - format: int32 description: 'Maximum upload bandwidth as KB/s, 0 means unlimited' download_bandwidth: type: integer - format: int32 description: 'Maximum download bandwidth as KB/s, 0 means unlimited' + upload_data_transfer: + type: integer + description: 'Maximum data transfer allowed for uploads as MB. 0 means no limit' + download_data_transfer: + type: integer + description: 'Maximum data transfer allowed for downloads as MB. 0 means no limit' + total_data_transfer: + type: integer + description: 'Maximum total data transfer as MB. 0 means unlimited. You can set a total data transfer instead of the individual values for uploads and downloads' + used_upload_data_transfer: + type: integer + description: 'Uploaded size, as bytes, since the last reset' + used_download_data_transfer: + type: integer + description: 'Downloaded size, as bytes, since the last reset' created_at: type: integer format: int64 @@ -4996,6 +5073,17 @@ components: used_quota_files: type: integer format: int32 + TransferQuotaUsage: + type: object + properties: + used_upload_data_transfer: + type: integer + format: int64 + description: 'The value must be specified as bytes' + used_download_data_transfer: + type: integer + format: int64 + description: 'The value must be specified as bytes' Transfer: type: object properties: diff --git a/service/service.go b/service/service.go index 07c51399..816514ce 100644 --- a/service/service.go +++ b/service/service.go @@ -86,7 +86,8 @@ func (s *Service) Start() error { return errors.New(infoString) } - err := common.Initialize(config.GetCommonConfig()) + providerConf := config.GetProviderConf() + err := common.Initialize(config.GetCommonConfig(), providerConf.GetShared()) if err != nil { logger.Error(logSender, "", "%v", err) logger.ErrorToConsole("%v", err) @@ -118,9 +119,6 @@ func (s *Service) Start() error { logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err) os.Exit(1) } - - providerConf := config.GetProviderConf() - err = dataprovider.Initialize(providerConf, s.ConfigDir, s.PortableMode == 0) if err != nil { logger.Error(logSender, "", "error initializing data provider: %v", err) diff --git a/sftpd/handler.go b/sftpd/handler.go index ab6edf7d..90d91f2e 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -62,6 +62,11 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } + transferQuota := c.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.Log(logger.LevelInfo, "denying file read due to quota limits") + return nil, c.GetReadQuotaExceededError() + } if ok, policy := c.User.IsFileAllowed(request.Filepath); !ok { c.Log(logger.LevelWarn, "reading file %#v is not allowed", request.Filepath) @@ -85,7 +90,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, transferQuota) t := newTransfer(baseTransfer, nil, r, nil) return t, nil @@ -271,7 +276,7 @@ func (c *Connection) StatVFS(r *sftp.Request) (*sftp.StatVFS, error) { // not produce any side effect here. // we don't consider c.User.Filters.MaxUploadFileSize, we return disk stats here // not the limit for a single file upload - quotaResult := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt")) + quotaResult, _ := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt")) fs, p, err := c.GetFsAndResolvedPath(r.Filepath) if err != nil { @@ -341,8 +346,8 @@ func (c *Connection) handleSFTPRemove(request *sftp.Request) error { } func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { - quotaResult := c.HasSpace(true, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(true, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, c.GetQuotaExceededError() } @@ -361,10 +366,10 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil @@ -373,8 +378,8 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath string, fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { var err error - quotaResult := c.HasSpace(false, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(false, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, c.GetQuotaExceededError() } @@ -388,7 +393,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO // if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before. // For Cloud FS GetMaxWriteSize will return unsupported operation - maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported()) + maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, fs.IsUploadResumeSupported()) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size: %v", err) return nil, err @@ -444,7 +449,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index abf7ced5..bfa4c521 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -163,7 +163,7 @@ func TestUploadResumeInvalidOffset(t *testing.T) { fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, - common.TransferUpload, 10, 0, 0, 0, false, fs) + common.TransferUpload, 10, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "upload with invalid offset must fail") @@ -195,7 +195,7 @@ func TestReadWriteErrors(t *testing.T) { fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) err = file.Close() assert.NoError(t, err) @@ -210,7 +210,7 @@ func TestReadWriteErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer = newTransfer(baseTransfer, nil, r, nil) err = transfer.Close() assert.NoError(t, err) @@ -221,7 +221,7 @@ func TestReadWriteErrors(t *testing.T) { assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer = newTransfer(baseTransfer, pipeWriter, nil, nil) err = r.Close() @@ -269,7 +269,7 @@ func TestTransferCancelFn(t *testing.T) { fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, - 0, 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) errFake := errors.New("fake error, this will trigger cancelFn") @@ -382,7 +382,7 @@ func TestSFTPGetUsedQuota(t *testing.T) { connection := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), } - quotaResult := connection.HasSpace(false, false, "/") + quotaResult, _ := connection.HasSpace(false, false, "/") assert.False(t, quotaResult.HasSpace) } @@ -977,7 +977,7 @@ func TestSystemCommandErrors(t *testing.T) { } sshCmd.connection.channel = &mockSSHChannel baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) destBuff := make([]byte, 65535) dst := bytes.NewBuffer(destBuff) @@ -993,7 +993,7 @@ func TestSystemCommandErrors(t *testing.T) { sshCmd.connection.channel = &mockSSHChannel transfer.MaxWriteSize = 1 _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel) - assert.EqualError(t, err, common.ErrQuotaExceeded.Error()) + assert.True(t, transfer.Connection.IsQuotaExceededError(err)) mockSSHChannel = MockChannel{ Buffer: bytes.NewBuffer(buf), @@ -1007,7 +1007,25 @@ func TestSystemCommandErrors(t *testing.T) { assert.EqualError(t, err, io.ErrShortWrite.Error()) transfer.MaxWriteSize = -1 _, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst) - assert.EqualError(t, err, common.ErrQuotaExceeded.Error()) + assert.True(t, transfer.Connection.IsQuotaExceededError(err)) + + baseTransfer = common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{ + AllowedDLSize: 1, + }) + transfer = newTransfer(baseTransfer, nil, nil, nil) + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + sshCmd.connection.channel = &mockSSHChannel + _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = os.RemoveAll(homeDir) assert.NoError(t, err) } @@ -1644,7 +1662,7 @@ func TestSCPUploadFiledata(t *testing.T) { assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(), - "/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs) + "/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) err = scpCommand.getUploadFileData(2, transfer) @@ -1729,7 +1747,7 @@ func TestUploadError(t *testing.T) { file, err := os.Create(fileTempName) assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(), - testfile, common.TransferUpload, 0, 0, 0, 0, true, fs) + testfile, common.TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer := newTransfer(baseTransfer, nil, nil, nil) errFake := errors.New("fake error") @@ -1788,7 +1806,7 @@ func TestTransferFailingReader(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), - common.TransferUpload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) errRead := errors.New("read is not allowed") tr := newTransfer(baseTransfer, nil, r, errRead) _, err = tr.ReadAt(buf, 0) diff --git a/sftpd/scp.go b/sftpd/scp.go index ee3a1947..03014ba5 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -212,11 +212,11 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err } func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { - quotaResult := c.connection.HasSpace(isNewFile, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { err := fmt.Errorf("denying file write due to quota limits") c.connection.Log(logger.LevelError, "error uploading file: %#v, err: %v", filePath, err) - c.sendErrorMessage(fs, err) + c.sendErrorMessage(nil, err) return err } err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, @@ -228,7 +228,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, return err } - maxWriteSize, _ := c.connection.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.connection.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { @@ -262,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota) t := newTransfer(baseTransfer, w, nil, nil) return c.getUploadFileData(sizeToRead, t) @@ -471,6 +471,12 @@ func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.Fi func (c *scpCommand) handleDownload(filePath string) error { c.connection.UpdateLastActivity() + transferQuota := c.connection.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.connection.Log(logger.LevelInfo, "denying file read due to quota limits") + c.sendErrorMessage(nil, c.connection.GetReadQuotaExceededError()) + return c.connection.GetReadQuotaExceededError() + } var err error fs, err := c.connection.User.GetFilesystemForPath(filePath, c.connection.ID) @@ -531,7 +537,7 @@ func (c *scpCommand) handleDownload(filePath string) error { } baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) t := newTransfer(baseTransfer, nil, r, nil) err = c.sendDownloadFileData(fs, p, stat, t) diff --git a/sftpd/server.go b/sftpd/server.go index def15f39..621d1074 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -89,8 +89,6 @@ type Configuration struct { // If set to a negative number, the number of attempts is unlimited. // If set to zero, the number of attempts are limited to 6. MaxAuthTries int `json:"max_auth_tries" mapstructure:"max_auth_tries"` - // Actions to execute on file operations and SSH commands - Actions common.ProtocolActions `json:"actions" mapstructure:"actions"` // HostKeys define the daemon's private host keys. // Each host key can be defined as a path relative to the configuration directory or an absolute one. // If empty or missing, the daemon will search or try to generate "id_rsa" and "id_ecdsa" host keys diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 57280bb1..29ce2eb3 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -178,7 +178,7 @@ func TestMain(m *testing.M) { scriptArgs = "$@" } - err = common.Initialize(commonConf) + err = common.Initialize(commonConf, 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) @@ -323,7 +323,7 @@ func TestMain(m *testing.M) { os.Remove(postConnectPath) os.Remove(preDownloadPath) os.Remove(preUploadPath) - //os.Remove(keyIntAuthPath) + os.Remove(keyIntAuthPath) os.Remove(checkPwdPath) os.Exit(exitCode) } @@ -434,6 +434,9 @@ func TestBasicSFTPHandling(t *testing.T) { err = os.Remove(localDownloadPath) assert.NoError(t, err) } + u.Username = "missing user" + _, _, err = getSftpClient(u, false) + assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) @@ -604,7 +607,7 @@ func TestRateLimiter(t *testing.T) { }, } - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) usePubKey := false @@ -625,7 +628,7 @@ func TestRateLimiter(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -637,7 +640,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) usePubKey := false @@ -666,7 +669,7 @@ func TestDefender(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -4052,6 +4055,65 @@ func TestQuotaLimits(t *testing.T) { assert.NoError(t, err) } +func TestTransferQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + // error while download is active + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + // error before starting the download + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + // error while upload is active + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + // error before starting the upload + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024)) + assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024)) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestUploadMaxSize(t *testing.T) { testFileSize := int64(65535) usePubKey := false @@ -8995,6 +9057,53 @@ func TestSCPPatternsFilter(t *testing.T) { assert.NoError(t, err) } +func TestSCPTransferQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.NoError(t, err) + // error while download is active + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + // error before starting the download + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + // error while upload is active + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + // error before starting the upload + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024)) + assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024)) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestSCPUploadMaxSize(t *testing.T) { testFileSize := int64(65535) usePubKey := true diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 7f591c36..bfcac9fb 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -51,6 +51,25 @@ type systemCommand struct { fs vfs.Fs } +func (c *systemCommand) GetSTDs() (io.WriteCloser, io.ReadCloser, io.ReadCloser, error) { + stdin, err := c.cmd.StdinPipe() + if err != nil { + return nil, nil, nil, err + } + stdout, err := c.cmd.StdoutPipe() + if err != nil { + stdin.Close() + return nil, nil, nil, err + } + stderr, err := c.cmd.StderrPipe() + if err != nil { + stdin.Close() + stdout.Close() + return nil, nil, nil, err + } + return stdin, stdout, stderr, nil +} + func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool { var msg sshSubsystemExecMsg if err := ssh.Unmarshal(payload, &msg); err == nil { @@ -309,8 +328,8 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { if !c.isLocalPath(sshDestPath) { return c.sendErrorResponse(errUnsupportedConfig) } - quotaResult := c.connection.HasSpace(true, false, command.quotaCheckPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.connection.HasSpace(true, false, command.quotaCheckPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() || !transferQuota.HasDownloadSpace() { return c.sendErrorResponse(common.ErrQuotaExceeded) } perms := []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermListItems, @@ -324,15 +343,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { return c.sendErrorResponse(err) } - stdin, err := command.cmd.StdinPipe() - if err != nil { - return c.sendErrorResponse(err) - } - stdout, err := command.cmd.StdoutPipe() - if err != nil { - return c.sendErrorResponse(err) - } - stderr, err := command.cmd.StderrPipe() + stdin, stdout, stderr, err := command.GetSTDs() if err != nil { return c.sendErrorResponse(err) } @@ -351,12 +362,12 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { var once sync.Once commandResponse := make(chan bool) - remainingQuotaSize := quotaResult.GetRemainingSize() + remainingQuotaSize := diskQuota.GetRemainingSize() go func() { defer stdin.Close() baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs) + common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs, transferQuota) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel) @@ -369,7 +380,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, 0, false, command.fs) + common.TransferDownload, 0, 0, 0, 0, false, command.fs, transferQuota) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout) @@ -383,7 +394,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, 0, false, command.fs) + common.TransferDownload, 0, 0, 0, 0, false, command.fs, transferQuota) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr) @@ -662,7 +673,7 @@ func (c *sshCommand) checkCopyDestination(fs vfs.Fs, fsDestPath string) error { } func (c *sshCommand) checkCopyQuota(numFiles int, filesSize int64, requestPath string) error { - quotaResult := c.connection.HasSpace(true, false, requestPath) + quotaResult, _ := c.connection.HasSpace(true, false, requestPath) if !quotaResult.HasSpace { return common.ErrQuotaExceeded } diff --git a/sftpd/transfer.go b/sftpd/transfer.go index dd63c58d..52b0a017 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -95,6 +95,9 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) { n, err = t.readerAt.ReadAt(p, off) atomic.AddInt64(&t.BytesSent, int64(n)) + if err == nil { + err = t.CheckRead() + } if err != nil && err != io.EOF { if t.GetType() == common.TransferDownload { t.TransferError(err) @@ -118,8 +121,8 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) { n, err = t.writerAt.WriteAt(p, off) atomic.AddInt64(&t.BytesReceived, int64(n)) - if t.MaxWriteSize > 0 && err == nil && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { - err = common.ErrQuotaExceeded + if err == nil { + err = t.CheckWrite() } if err != nil { t.TransferError(err) @@ -197,12 +200,16 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, written += int64(nw) if isDownload { atomic.StoreInt64(&t.BytesSent, written) + if errCheck := t.CheckRead(); errCheck != nil { + err = errCheck + break + } } else { atomic.StoreInt64(&t.BytesReceived, written) - } - if t.MaxWriteSize > 0 && written > t.MaxWriteSize { - err = common.ErrQuotaExceeded - break + if errCheck := t.CheckWrite(); errCheck != nil { + err = errCheck + break + } } } if ew != nil { diff --git a/templates/webadmin/user.html b/templates/webadmin/user.html index c473eb26..9ee39289 100644 --- a/templates/webadmin/user.html +++ b/templates/webadmin/user.html @@ -509,7 +509,7 @@ - Comma separated IP/Mask in CIDR format, for example "192.168.1.0/24,10.8.0.100/32" + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32" @@ -520,7 +520,7 @@ - Comma separated IP/Mask in CIDR format, for example "192.168.1.0/24,10.8.0.100/32" + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32" @@ -594,7 +594,7 @@
- Per-source bandwidth limits + Per-source bandwidth speed limits
@@ -605,7 +605,7 @@ - Comma separated IP/Mask in CIDR format, for example "192.168.1.0/24,10.8.0.100/32" + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32"
@@ -637,7 +637,7 @@ - Comma separated IP/Mask in CIDR format, for example "192.168.1.0/24,10.8.0.100/32" + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32"
@@ -668,7 +668,138 @@
+
+
+
+ +
+ +
+ + + Maximum data transfer allowed for uploads. 0 means no limit + +
+
+ +
+ + + Maximum data transfer allowed for downloads. 0 means no limit + +
+
+ +
+ +
+ + + Maximum data transfer allowed for uploads + downloads. Replace the individual limits. 0 means no limit + +
+
+ +
+
+ Per-source data transfer limits +
+
+
+
+ {{range $idx, $dtLimit := .User.Filters.DataTransferLimits -}} +
+
+ + + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32" + +
+
+
+ + + UL (MB). 0 means no limit + +
+
+ + + DL (MB). 0 means no limit + +
+
+
+
+ + + Total (MB). 0 means no limit + +
+
+
+ +
+
+ {{else}} +
+
+ + + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32" + +
+
+
+ + + UL (MB). 0 means no limit + +
+
+ + + DL (MB). 0 means no limit + +
+
+
+
+ + + Total (MB). 0 means no limit + +
+
+
+ +
+
+ {{end}} +
+
+ +
+
@@ -930,7 +1061,7 @@ - Comma separated IP/Mask in CIDR format, for example "192.168.1.0/24,10.8.0.100/32" + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32"
@@ -962,6 +1093,58 @@ $(this).closest(".form_field_bwlimits_outer_row").remove(); }); + $("body").on("click", ".add_new_dtlimit_field_btn", function () { + var index = $(".form_field_dtlimits_outer").find(".form_field_dtlimits_outer_row").length; + while (document.getElementById("idDataTransferLimitSources"+index) != null){ + index++; + } + $(".form_field_dtlimits_outer").append(` +
+
+ + + Comma separated IP/Mask in CIDR format, example: "192.168.1.0/24,10.8.0.100/32" + +
+
+
+ + + UL (MB). 0 means no limit + +
+
+ + + DL (MB). 0 means no limit + +
+
+
+
+ + + Total (MB). 0 means no limit + +
+
+
+ +
+
+ `); + }); + + $("body").on("click", ".remove_dtlimit_btn_frm_field", function () { + $(this).closest(".form_field_dtlimits_outer_row").remove(); + }); + $("body").on("click", ".add_new_pattern_field_btn", function () { var index = $(".form_field_patterns_outer").find(".form_field_patterns_outer_row").length; while (document.getElementById("idPatternPath"+index) != null){ diff --git a/webdavd/file.go b/webdavd/file.go index c00a6856..141bb33b 100644 --- a/webdavd/file.go +++ b/webdavd/file.go @@ -142,6 +142,11 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) { return 0, f.Connection.GetPermissionDeniedError() } + transferQuota := f.BaseTransfer.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + f.Connection.Log(logger.LevelInfo, "denying file read due to quota limits") + return 0, f.Connection.GetReadQuotaExceededError() + } if ok, policy := f.Connection.User.IsFileAllowed(f.GetVirtualPath()); !ok { f.Connection.Log(logger.LevelWarn, "reading file %#v is not allowed", f.GetVirtualPath()) @@ -180,7 +185,9 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { n, err = f.reader.Read(p) atomic.AddInt64(&f.BytesSent, int64(n)) - + if err == nil { + err = f.CheckRead() + } if err != nil && err != io.EOF { f.TransferError(err) return @@ -200,8 +207,8 @@ func (f *webDavFile) Write(p []byte) (n int, err error) { n, err = f.writer.Write(p) atomic.AddInt64(&f.BytesReceived, int64(n)) - if f.MaxWriteSize > 0 && err == nil && atomic.LoadInt64(&f.BytesReceived) > f.MaxWriteSize { - err = common.ErrQuotaExceeded + if err == nil { + err = f.CheckWrite() } if err != nil { f.TransferError(err) @@ -260,6 +267,9 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { startByte := int64(0) atomic.StoreInt64(&f.BytesReceived, 0) atomic.StoreInt64(&f.BytesSent, 0) + go func(ulSize, dlSize int64, user dataprovider.User) { + dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck + }(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User) switch whence { case io.SeekStart: diff --git a/webdavd/handler.go b/webdavd/handler.go index ee9ab1e5..35ede235 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -150,7 +150,7 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, c.GetTransferQuota()) return newWebDavFile(baseTransfer, nil, r), nil } @@ -193,8 +193,8 @@ func (c *Connection) putFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File } func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (webdav.File, error) { - quotaResult := c.HasSpace(true, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(true, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } @@ -211,19 +211,20 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) return newWebDavFile(baseTransfer, w, nil), nil } func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePath string, fileSize int64, - requestPath string) (webdav.File, error) { + requestPath string, +) (webdav.File, error) { var err error - quotaResult := c.HasSpace(false, false, requestPath) - if !quotaResult.HasSpace { + diskQuota, transferQuota := c.HasSpace(false, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } @@ -235,7 +236,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported()) + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { err = fs.Rename(resolvedPath, filePath) @@ -271,7 +272,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) return newWebDavFile(baseTransfer, w, nil), nil } diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index 424c628c..21c2950f 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -695,7 +695,7 @@ func TestContentType(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) ctx := context.Background() baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil) err := os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) @@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) { } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferUpload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile := newWebDavFile(baseTransfer, nil, nil) p := make([]byte, 1) _, err := davFile.Read(p) @@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Read(p) assert.True(t, os.IsNotExist(err)) @@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.True(t, os.IsNotExist(err)) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) err = os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) f, err := os.Open(testFilePath) @@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, mockFs) + common.TransferDownload, 0, 0, 0, 0, false, mockFs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) writeContent := []byte("content\r\n") @@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.writer = f err = davFile.Close() @@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) testFileContents := []byte("content") baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferUpload, 0, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile := newWebDavFile(baseTransfer, nil, nil) _, err := davFile.Seek(0, io.SeekStart) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) @@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekCurrent) assert.True(t, os.IsNotExist(err)) @@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekStart) assert.Error(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) res, err := davFile.Seek(0, io.SeekStart) assert.NoError(t, err) @@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) { assert.Nil(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekEnd) assert.True(t, os.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.reader = f davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) @@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) { assert.Equal(t, int64(5), res) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index 250dc3ac..9e674e4a 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -298,7 +298,7 @@ func TestMain(m *testing.M) { os.Exit(1) } - err = common.Initialize(commonConf) + err = common.Initialize(commonConf, 0) if err != nil { logger.WarnToConsole("error initializing common: %v", err) os.Exit(1) @@ -509,11 +509,17 @@ func TestBasicHandling(t *testing.T) { expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) // overwrite an existing file - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) + // wrong password + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword+"1", + true, testFileSize, client) + assert.Error(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) @@ -549,9 +555,11 @@ func TestBasicHandling(t *testing.T) { assert.NoError(t, err) err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(testDir, testFileName+".txt"), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName+".txt"), + user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName), + user.Username, defaultPassword, true, testFileSize, client) assert.NoError(t, err) files, err := client.ReadDir(testDir) assert.NoError(t, err) @@ -597,10 +605,12 @@ func TestBasicHandlingCryptFs(t *testing.T) { expectedQuotaFiles := user.UsedQuotaFiles + 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, + user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) // overwrite an existing file - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, + user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) @@ -631,9 +641,11 @@ func TestBasicHandlingCryptFs(t *testing.T) { assert.NoError(t, err) err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(testDir, testFileName+".txt"), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName+".txt"), + user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName), + user.Username, defaultPassword, false, testFileSize, client) assert.NoError(t, err) files, err = client.ReadDir(testDir) assert.NoError(t, err) @@ -667,7 +679,8 @@ func TestLockAfterDelete(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) assert.NoError(t, err) lockBody := `` req, err := http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody))) @@ -723,7 +736,8 @@ func TestRenameWithLock(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) assert.NoError(t, err) lockBody := `` @@ -779,7 +793,8 @@ func TestPropPatch(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) assert.NoError(t, err) httpClient := httpclient.GetHTTPClient() propatchBody := `Wed, 04 Nov 2020 13:25:51 GMTSat, 05 Dec 2020 21:16:12 GMTWed, 04 Nov 2020 13:25:51 GMT00000000` @@ -842,7 +857,7 @@ func TestRateLimiter(t *testing.T) { }, } - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) @@ -860,7 +875,7 @@ func TestRateLimiter(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -872,7 +887,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 - err := common.Initialize(cfg) + err := common.Initialize(cfg, 0) assert.NoError(t, err) user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) @@ -898,7 +913,7 @@ func TestDefender(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - err = common.Initialize(oldConfig) + err = common.Initialize(oldConfig, 0) assert.NoError(t, err) } @@ -1022,7 +1037,8 @@ func TestPreDownloadHook(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) @@ -1071,16 +1087,19 @@ func TestPreUploadHook(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.Error(t, err) - err = uploadFile(testFilePath, testFileName+"1", testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName+"1", user.Username, defaultPassword, + false, testFileSize, client) assert.Error(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -1298,26 +1317,35 @@ func TestUploadErrors(t *testing.T) { assert.NoError(t, err) err = client.Mkdir(subDir2, os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(subDir1, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir1, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.Error(t, err) - err = uploadFile(testFilePath, path.Join(subDir2, testFileName+".zip"), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName+".zip"), user.Username, + defaultPassword, true, testFileSize, client) + assert.Error(t, err) - err = uploadFile(testFilePath, path.Join(subDir2, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = client.Rename(path.Join(subDir2, testFileName), path.Join(subDir1, testFileName), false) assert.Error(t, err) - err = uploadFile(testFilePath, path.Join(subDir2, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.Error(t, err) - err = uploadFile(testFilePath, subDir1, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, subDir1, user.Username, + defaultPassword, true, testFileSize, client) assert.Error(t, err) // overquota - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.Error(t, err) err = client.Remove(path.Join(subDir2, testFileName)) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.Error(t, err) err = os.Remove(testFilePath) @@ -1394,7 +1422,8 @@ func TestQuotaLimits(t *testing.T) { assert.NoError(t, err) client := getWebDavClient(user, false, nil) // test quota files - err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, false, + testFileSize, client) if !assert.NoError(t, err, "username: %v", user.Username) { info, err := os.Stat(testFilePath) if assert.NoError(t, err) { @@ -1402,7 +1431,8 @@ func TestQuotaLimits(t *testing.T) { } printLatestLogs(20) } - err = uploadFile(testFilePath, testFileName+".quota1", testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName+".quota1", user.Username, defaultPassword, + false, testFileSize, client) assert.Error(t, err, "username: %v", user.Username) err = client.Rename(testFileName+".quota", testFileName, false) assert.NoError(t, err) @@ -1414,7 +1444,8 @@ func TestQuotaLimits(t *testing.T) { user.QuotaFiles = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName+".quota", testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, + false, testFileSize, client) assert.Error(t, err) err = client.Rename(testFileName, testFileName+".quota", false) assert.NoError(t, err) @@ -1423,20 +1454,25 @@ func TestQuotaLimits(t *testing.T) { user.QuotaFiles = 0 user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) - err = uploadFile(testFilePath1, testFileName1, testFileSize1, client) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) assert.Error(t, err) _, err = client.Stat(testFileName1) assert.Error(t, err) err = client.Rename(testFileName+".quota", testFileName, false) assert.NoError(t, err) // overwriting an existing file will work if the resulting size is lesser or equal than the current one - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath2, testFileName, testFileSize2, client) + err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, + false, testFileSize2, client) assert.NoError(t, err) - err = uploadFile(testFilePath1, testFileName, testFileSize1, client) + err = uploadFileWithRawClient(testFilePath1, testFileName, user.Username, defaultPassword, + false, testFileSize1, client) assert.Error(t, err) - err = uploadFile(testFilePath2, testFileName, testFileSize2, client) + err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, + false, testFileSize2, client) assert.NoError(t, err) err = os.Remove(testFilePath) @@ -1462,6 +1498,49 @@ func TestQuotaLimits(t *testing.T) { assert.NoError(t, err) } +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + client := getWebDavClient(user, false, nil) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + // error while download is active + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + // error before starting the download + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + // error while upload is active + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + // error before starting the upload + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestUploadMaxSize(t *testing.T) { testFileSize := int64(65535) u := getTestUser() @@ -1482,14 +1561,17 @@ func TestUploadMaxSize(t *testing.T) { err = createTestFile(testFilePath1, testFileSize1) assert.NoError(t, err) client := getWebDavClient(user, false, nil) - err = uploadFile(testFilePath1, testFileName1, testFileSize1, client) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) assert.Error(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) assert.NoError(t, err) // now test overwrite an existing file with a size bigger than the allowed one err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1) assert.NoError(t, err) - err = uploadFile(testFilePath1, testFileName1, testFileSize1, client) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) assert.Error(t, err) err = os.Remove(testFilePath) @@ -1534,7 +1616,8 @@ func TestClientClose(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.Error(t, err) wg.Done() }() @@ -1691,10 +1774,12 @@ func TestSFTPBuffered(t *testing.T) { expectedQuotaFiles := 1 err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) // overwrite an existing file - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) localDownloadPath := filepath.Join(homeBasePath, testDLFileName) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) @@ -1708,7 +1793,8 @@ func TestSFTPBuffered(t *testing.T) { fileContent := []byte("test file contents") err = os.WriteFile(testFilePath, fileContent, os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, int64(len(fileContent)), client) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, int64(len(fileContent)), client) assert.NoError(t, err) remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) req, err := http.NewRequest(http.MethodGet, remotePath, nil) @@ -1763,7 +1849,8 @@ func TestBytesRangeRequests(t *testing.T) { err = os.WriteFile(testFilePath, fileContent, os.ModePerm) assert.NoError(t, err) client := getWebDavClient(user, true, nil) - err = uploadFile(testFilePath, testFileName, int64(len(fileContent)), client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, int64(len(fileContent)), client) assert.NoError(t, err) remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) req, err := http.NewRequest(http.MethodGet, remotePath, nil) @@ -1902,9 +1989,11 @@ func TestStat(t *testing.T) { assert.NoError(t, err) err = client.Mkdir(subDir, os.ModePerm) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join("/", subDir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join("/", subDir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) user.Permissions["/subdir"] = []string{dataprovider.PermUpload, dataprovider.PermDownload} user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") @@ -1960,13 +2049,15 @@ func TestUploadOverwriteVfolder(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) assert.Equal(t, testFileSize, folder.UsedQuotaSize) assert.Equal(t, 1, folder.UsedQuotaFiles) - err = uploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) assert.NoError(t, err) @@ -2059,11 +2150,14 @@ func TestMiscCommands(t *testing.T) { testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(dir, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(dir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(dir, "sub1", testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(dir, "sub1", testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(dir, "sub1", "sub2", testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(dir, "sub1", "sub2", testFileName), user.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = client.Copy(dir, dir+"_copy", false) assert.NoError(t, err) @@ -2547,23 +2641,28 @@ func TestNestedVirtualFolders(t *testing.T) { err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - err = uploadFile(testFilePath, testFileName, testFileSize, client) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(testFileName, localDownloadPath, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join("/vdir", testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(vdirPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(vdirCryptPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) - err = uploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client) + err = uploadFileWithRawClient(testFilePath, path.Join(vdirNestedPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) assert.NoError(t, err) err = downloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client) assert.NoError(t, err) @@ -2613,7 +2712,54 @@ func checkFileSize(remoteDestPath string, expectedSize int64, client *gowebdav.C return nil } -func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { +func uploadFileWithRawClient(localSourcePath string, remoteDestPath string, username, password string, + useTLS bool, expectedSize int64, client *gowebdav.Client, +) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + + var tlsConfig *tls.Config + rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) + if useTLS { + rootPath = fmt.Sprintf("https://%v/", webDavTLSServerAddr) + tlsConfig = &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + } + req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%v%v", rootPath, remoteDestPath), srcFile) + if err != nil { + return err + } + req.SetBasicAuth(username, password) + httpClient := &http.Client{Timeout: 10 * time.Second} + if tlsConfig != nil { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = tlsConfig + httpClient.Transport = customTransport + } + defer httpClient.CloseIdleConnections() + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("unexpected status code: %v", resp.StatusCode) + } + if expectedSize > 0 { + return checkFileSize(remoteDestPath, expectedSize, client) + } + return nil +} + +// This method is buggy. I have to find time to better investigate and eventually report the issue upstream. +// For now we upload using the uploadFileWithRawClient method +/*func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { srcFile, err := os.Open(localSourcePath) if err != nil { return err @@ -2625,13 +2771,9 @@ func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int6 } if expectedSize > 0 { return checkFileSize(remoteDestPath, expectedSize, client) - /*if err != nil { - time.Sleep(1 * time.Second) - return checkFileSize(remoteDestPath, expectedSize, client) - }*/ } return nil -} +}*/ func downloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *gowebdav.Client) error { downloadDest, err := os.Create(localDestPath) @@ -2797,7 +2939,18 @@ func createTestFile(path string, size int64) error { return err } - return os.WriteFile(path, content, os.ModePerm) + err = os.WriteFile(path, content, os.ModePerm) + if err != nil { + return err + } + fi, err := os.Stat(path) + if err != nil { + return err + } + if fi.Size() != size { + return fmt.Errorf("unexpected size %v, expected %v", fi.Size(), size) + } + return nil } func printLatestLogs(maxNumberOfLines int) {