From d2a41788462c9d9a036d698d0241fc2f74e4de84 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 20 Jan 2022 18:19:20 +0100 Subject: [PATCH] check quota usage between ongoing transfers Signed-off-by: Nicola Murino --- common/common.go | 131 +++++++--- common/common_test.go | 14 +- common/connection.go | 70 ++++- common/transfer.go | 54 +++- common/transfer_test.go | 25 +- common/transferschecker.go | 167 ++++++++++++ common/transferschecker_test.go | 449 ++++++++++++++++++++++++++++++++ dataprovider/bolt.go | 47 ++++ dataprovider/dataprovider.go | 26 ++ dataprovider/memory.go | 51 +++- dataprovider/mysql.go | 4 + dataprovider/pgsql.go | 4 + dataprovider/sqlcommon.go | 84 ++++++ dataprovider/sqlite.go | 4 + dataprovider/sqlqueries.go | 19 +- ftpd/handler.go | 10 +- ftpd/internal_test.go | 6 +- go.mod | 16 +- go.sum | 31 ++- httpd/file.go | 11 +- httpd/handler.go | 46 +++- httpd/internal_test.go | 9 +- sftpd/handler.go | 8 +- sftpd/internal_test.go | 26 +- sftpd/scp.go | 6 +- sftpd/ssh_cmd.go | 6 +- tests/eventsearcher/go.mod | 10 +- tests/eventsearcher/go.sum | 16 +- webdavd/handler.go | 10 +- webdavd/internal_test.go | 26 +- 30 files changed, 1228 insertions(+), 158 deletions(-) create mode 100644 common/transferschecker.go create mode 100644 common/transferschecker_test.go diff --git a/common/common.go b/common/common.go index d6f441d7..dffadb0c 100644 --- a/common/common.go +++ b/common/common.go @@ -53,9 +53,10 @@ const ( operationMkdir = "mkdir" operationRmdir = "rmdir" // SSH command action name - OperationSSHCmd = "ssh_cmd" - chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS - idleTimeoutCheckInterval = 3 * time.Minute + OperationSSHCmd = "ssh_cmd" + chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS + idleTimeoutCheckInterval = 3 * time.Minute + periodicTimeoutCheckInterval = 1 * time.Minute ) // Stat flags @@ -110,6 +111,7 @@ var ( ErrCrtRevoked = errors.New("your certificate has been revoked") ErrNoCredentials = errors.New("no credential provided") ErrInternalFailure = errors.New("internal failure") + ErrTransferAborted = errors.New("transfer aborted") errNoTransfer = errors.New("requested transfer not found") errTransferMismatch = errors.New("transfer mismatch") ) @@ -120,10 +122,11 @@ var ( // Connections is the list of active connections Connections ActiveConnections // QuotaScans is the list of active quota scans - QuotaScans ActiveScans - idleTimeoutTicker *time.Ticker - idleTimeoutTickerDone chan bool - supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, + QuotaScans ActiveScans + transfersChecker TransfersChecker + periodicTimeoutTicker *time.Ticker + periodicTimeoutTickerDone chan bool + supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP, ProtocolHTTPShare} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} // the map key is the protocol, for each protocol we can have multiple rate limiters @@ -135,9 +138,7 @@ func Initialize(c Configuration) error { Config = c Config.idleLoginTimeout = 2 * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute - if Config.IdleTimeout > 0 { - startIdleTimeoutTicker(idleTimeoutCheckInterval) - } + startPeriodicTimeoutTicker(periodicTimeoutCheckInterval) Config.defender = nil rateLimiters = make(map[string][]*rateLimiter) for _, rlCfg := range c.RateLimitersConfig { @@ -176,6 +177,7 @@ func Initialize(c Configuration) error { } vfs.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath) + transfersChecker = getTransfersChecker() return nil } @@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) { } // the ticker cannot be started/stopped from multiple goroutines -func startIdleTimeoutTicker(duration time.Duration) { - stopIdleTimeoutTicker() - idleTimeoutTicker = time.NewTicker(duration) - idleTimeoutTickerDone = make(chan bool) +func startPeriodicTimeoutTicker(duration time.Duration) { + stopPeriodicTimeoutTicker() + periodicTimeoutTicker = time.NewTicker(duration) + periodicTimeoutTickerDone = make(chan bool) go func() { + counter := int64(0) + ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval for { select { - case <-idleTimeoutTickerDone: + case <-periodicTimeoutTickerDone: return - case <-idleTimeoutTicker.C: - Connections.checkIdles() + case <-periodicTimeoutTicker.C: + counter++ + if Config.IdleTimeout > 0 && counter >= int64(ratio) { + counter = 0 + Connections.checkIdles() + } + go Connections.checkTransfers() } } }() } -func stopIdleTimeoutTicker() { - if idleTimeoutTicker != nil { - idleTimeoutTicker.Stop() - idleTimeoutTickerDone <- true - idleTimeoutTicker = nil +func stopPeriodicTimeoutTicker() { + if periodicTimeoutTicker != nil { + periodicTimeoutTicker.Stop() + periodicTimeoutTickerDone <- true + periodicTimeoutTicker = nil } } // ActiveTransfer defines the interface for the current active transfers type ActiveTransfer interface { - GetID() uint64 + GetID() int64 GetType() int GetSize() int64 + GetDownloadedSize() int64 + GetUploadedSize() int64 GetVirtualPath() string GetStartTime() time.Time - SignalClose() + SignalClose(err error) Truncate(fsPath string, size int64) (int64, error) GetRealFsPath(fsPath string) string SetTimes(fsPath string, atime time.Time, mtime time.Time) bool + GetTruncatedSize() int64 + GetMaxAllowedSize() int64 } // ActiveConnection defines the interface for the current active connections @@ -319,6 +332,7 @@ type ActiveConnection interface { AddTransfer(t ActiveTransfer) RemoveTransfer(t ActiveTransfer) GetTransfers() []ConnectionTransfer + SignalTransferClose(transferID int64, err error) CloseFS() error } @@ -335,11 +349,14 @@ type StatAttributes struct { // ConnectionTransfer defines the trasfer details to expose type ConnectionTransfer struct { - ID uint64 `json:"-"` - OperationType string `json:"operation_type"` - StartTime int64 `json:"start_time"` - Size int64 `json:"size"` - VirtualPath string `json:"path"` + ID int64 `json:"-"` + OperationType string `json:"operation_type"` + StartTime int64 `json:"start_time"` + Size int64 `json:"size"` + VirtualPath string `json:"path"` + MaxAllowedSize int64 `json:"-"` + ULSize int64 `json:"-"` + DLSize int64 `json:"-"` } func (t *ConnectionTransfer) getConnectionTransferAsString() string { @@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error { type ActiveConnections struct { // clients contains both authenticated and estabilished connections and the ones waiting // for authentication - clients clientsMap + clients clientsMap + transfersCheckStatus int32 sync.RWMutex connections []ActiveConnection sshConnections []*SSHConnection @@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() { conns.RUnlock() } +func (conns *ActiveConnections) checkTransfers() { + if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 { + logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution") + return + } + atomic.StoreInt32(&conns.transfersCheckStatus, 1) + defer atomic.StoreInt32(&conns.transfersCheckStatus, 0) + + var wg sync.WaitGroup + + logger.Debug(logSender, "", "start concurrent transfers check") + conns.RLock() + + // update the current size for transfers to monitors + for _, c := range conns.connections { + for _, t := range c.GetTransfers() { + if t.MaxAllowedSize > 0 { + wg.Add(1) + + go func(transfer ConnectionTransfer, connID string) { + defer wg.Done() + transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID) + }(t, c.GetID()) + } + } + } + + conns.RUnlock() + logger.Debug(logSender, "", "waiting for the update of the transfers current size") + wg.Wait() + + logger.Debug(logSender, "", "getting overquota transfers") + overquotaTransfers := transfersChecker.GetOverquotaTransfers() + logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers)) + if len(overquotaTransfers) == 0 { + return + } + + conns.RLock() + defer conns.RUnlock() + + for _, c := range conns.connections { + for _, overquotaTransfer := range overquotaTransfers { + if c.GetID() == overquotaTransfer.ConnID { + logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ", + c.GetUsername(), overquotaTransfer.TransferID) + c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol())) + } + } + } + logger.Debug(logSender, "", "transfers check completed") +} + // AddClientConnection stores a new client connection func (conns *ActiveConnections) AddClientConnection(ipAddr string) { conns.clients.add(ipAddr) diff --git a/common/common_test.go b/common/common_test.go index eea8221d..694c2e44 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -408,19 +408,19 @@ func TestIdleConnections(t *testing.T) { assert.Len(t, Connections.sshConnections, 2) Connections.RUnlock() - startIdleTimeoutTicker(100 * time.Millisecond) + startPeriodicTimeoutTicker(100 * time.Millisecond) assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { Connections.RLock() defer Connections.RUnlock() return len(Connections.sshConnections) == 1 }, 1*time.Second, 200*time.Millisecond) - stopIdleTimeoutTicker() + stopPeriodicTimeoutTicker() assert.Len(t, Connections.GetStats(), 2) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() sshConn2.lastActivity = c.lastActivity - startIdleTimeoutTicker(100 * time.Millisecond) + startPeriodicTimeoutTicker(100 * time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { Connections.RLock() @@ -428,7 +428,7 @@ func TestIdleConnections(t *testing.T) { return len(Connections.sshConnections) == 0 }, 1*time.Second, 200*time.Millisecond) assert.Equal(t, int32(0), Connections.GetClientConnections()) - stopIdleTimeoutTicker() + stopPeriodicTimeoutTicker() assert.True(t, customConn1.isClosed) assert.True(t, customConn2.isClosed) @@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) { fakeConn1 := &fakeConnection{ BaseConnection: c1, } - t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs) + t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs) t1.BytesReceived = 123 - t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) + t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs) t2.BytesSent = 456 c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user) fakeConn2 := &fakeConnection{ @@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) { BaseConnection: c3, command: "PROPFIND", } - t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) + t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs) Connections.Add(fakeConn1) Connections.Add(fakeConn2) Connections.Add(fakeConn3) diff --git a/common/connection.go b/common/connection.go index 08b45249..c3b2cce4 100644 --- a/common/connection.go +++ b/common/connection.go @@ -27,7 +27,7 @@ type BaseConnection struct { lastActivity int64 // unique ID for a transfer. // This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment - transferID uint64 + transferID int64 // Unique identifier for the connection ID string // user associated with this connection if any @@ -66,8 +66,8 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interfac } // GetTransferID returns an unique transfer ID for this connection -func (c *BaseConnection) GetTransferID() uint64 { - return atomic.AddUint64(&c.transferID, 1) +func (c *BaseConnection) GetTransferID() int64 { + return atomic.AddInt64(&c.transferID, 1) } // GetID returns the connection ID @@ -125,6 +125,27 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) { c.activeTransfers = append(c.activeTransfers, t) c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers)) + if t.GetMaxAllowedSize() > 0 { + folderName := "" + if t.GetType() == TransferUpload { + vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath())) + if err == nil { + if !vfolder.IsIncludedInUserQuota() { + folderName = vfolder.Name + } + } + } + go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{ + ID: t.GetID(), + Type: t.GetType(), + ConnID: c.ID, + Username: c.GetUsername(), + FolderName: folderName, + TruncatedSize: t.GetTruncatedSize(), + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + } } // RemoveTransfer removes the specified transfer from the active ones @@ -132,6 +153,10 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) { c.Lock() defer c.Unlock() + if t.GetMaxAllowedSize() > 0 { + go transfersChecker.RemoveTransfer(t.GetID(), c.ID) + } + for idx, transfer := range c.activeTransfers { if transfer.GetID() == t.GetID() { lastIdx := len(c.activeTransfers) - 1 @@ -145,6 +170,20 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) { c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID()) } +// SignalTransferClose makes the transfer fail on the next read/write with the +// specified error +func (c *BaseConnection) SignalTransferClose(transferID int64, err error) { + c.RLock() + defer c.RUnlock() + + for _, t := range c.activeTransfers { + if t.GetID() == transferID { + c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID) + t.SignalClose(err) + } + } +} + // GetTransfers returns the active transfers func (c *BaseConnection) GetTransfers() []ConnectionTransfer { c.RLock() @@ -160,11 +199,14 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer { operationType = operationUpload } transfers = append(transfers, ConnectionTransfer{ - ID: t.GetID(), - OperationType: operationType, - StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), - Size: t.GetSize(), - VirtualPath: t.GetVirtualPath(), + ID: t.GetID(), + OperationType: operationType, + StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), + Size: t.GetSize(), + VirtualPath: t.GetVirtualPath(), + MaxAllowedSize: t.GetMaxAllowedSize(), + ULSize: t.GetUploadedSize(), + DLSize: t.GetDownloadedSize(), }) } @@ -181,7 +223,7 @@ func (c *BaseConnection) SignalTransfersAbort() error { } for _, t := range c.activeTransfers { - t.SignalClose() + t.SignalClose(ErrTransferAborted) } return nil } @@ -1208,9 +1250,8 @@ func (c *BaseConnection) GetOpUnsupportedError() error { } } -// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol -func (c *BaseConnection) GetQuotaExceededError() error { - switch c.protocol { +func getQuotaExceededError(protocol string) error { + switch protocol { case ProtocolSFTP: return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error()) case ProtocolFTP: @@ -1220,6 +1261,11 @@ func (c *BaseConnection) GetQuotaExceededError() error { } } +// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol +func (c *BaseConnection) GetQuotaExceededError() error { + return getQuotaExceededError(c.protocol) +} + // IsQuotaExceededError returns true if the given error is a quota exceeded error func (c *BaseConnection) IsQuotaExceededError(err error) bool { switch c.protocol { diff --git a/common/transfer.go b/common/transfer.go index 2a50650d..95c479cb 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -20,7 +20,7 @@ var ( // BaseTransfer contains protocols common transfer details for an upload or a download. type BaseTransfer struct { //nolint:maligned - ID uint64 + ID int64 BytesSent int64 BytesReceived int64 Fs vfs.Fs @@ -35,18 +35,21 @@ type BaseTransfer struct { //nolint:maligned MaxWriteSize int64 MinWriteOffset int64 InitialSize int64 + truncatedSize int64 isNewFile bool transferType int AbortTransfer int32 aTime time.Time mTime time.Time sync.Mutex + errAbort error ErrTransfer error } // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string, - transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer { + transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs, +) *BaseTransfer { t := &BaseTransfer{ ID: conn.GetTransferID(), File: file, @@ -64,6 +67,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat BytesReceived: 0, MaxWriteSize: maxWriteSize, AbortTransfer: 0, + truncatedSize: truncatedSize, Fs: fs, } @@ -77,7 +81,7 @@ func (t *BaseTransfer) SetFtpMode(mode string) { } // GetID returns the transfer ID -func (t *BaseTransfer) GetID() uint64 { +func (t *BaseTransfer) GetID() int64 { return t.ID } @@ -94,19 +98,53 @@ func (t *BaseTransfer) GetSize() int64 { return atomic.LoadInt64(&t.BytesReceived) } +// GetDownloadedSize returns the transferred size +func (t *BaseTransfer) GetDownloadedSize() int64 { + return atomic.LoadInt64(&t.BytesSent) +} + +// GetUploadedSize returns the transferred size +func (t *BaseTransfer) GetUploadedSize() int64 { + return atomic.LoadInt64(&t.BytesReceived) +} + // GetStartTime returns the start time func (t *BaseTransfer) GetStartTime() time.Time { return t.start } -// SignalClose signals that the transfer should be closed. -// For same protocols, for example WebDAV, we have no -// access to the network connection, so we use this method -// to make the next read or write to fail -func (t *BaseTransfer) SignalClose() { +// GetAbortError returns the error to send to the client if the transfer was aborted +func (t *BaseTransfer) GetAbortError() error { + t.Lock() + defer t.Unlock() + + if t.errAbort != nil { + return t.errAbort + } + return getQuotaExceededError(t.Connection.protocol) +} + +// SignalClose signals that the transfer should be closed after the next read/write. +// The optional error argument allow to send a specific error, otherwise a generic +// transfer aborted error is sent +func (t *BaseTransfer) SignalClose(err error) { + t.Lock() + t.errAbort = err + t.Unlock() atomic.StoreInt32(&(t.AbortTransfer), 1) } +// GetTruncatedSize returns the truncated sized if this is an upload overwriting +// an existing file +func (t *BaseTransfer) GetTruncatedSize() int64 { + return t.truncatedSize +} + +// GetMaxAllowedSize returns the max allowed size +func (t *BaseTransfer) GetMaxAllowedSize() int64 { + return t.MaxWriteSize +} + // GetVirtualPath returns the transfer virtual path func (t *BaseTransfer) GetVirtualPath() string { return t.requestPath diff --git a/common/transfer_test.go b/common/transfer_test.go index 5b1b8152..f498cab1 100644 --- a/common/transfer_test.go +++ b/common/transfer_test.go @@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) { wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10 conn := NewBaseConnection("id", ProtocolSCP, "", "", u) - transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, true, fs) + transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs) transfer.BytesReceived = testFileSize transfer.Connection.UpdateLastActivity() startTime := transfer.Connection.GetLastActivity() @@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) { err := transfer.Close() assert.NoError(t, err) - transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs) transfer.BytesSent = testFileSize transfer.Connection.UpdateLastActivity() startTime = transfer.Connection.GetLastActivity() @@ -101,7 +101,8 @@ func TestRealPath(t *testing.T) { file, err := os.Create(testFile) require.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) - transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) + transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", + TransferUpload, 0, 0, 0, 0, true, fs) rPath := transfer.GetRealFsPath(testFile) assert.Equal(t, testFile, rPath) rPath = conn.getRealFsPath(testFile) @@ -138,7 +139,8 @@ func TestTruncate(t *testing.T) { _, err = file.Write([]byte("hello")) assert.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) - transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs) + transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, + 100, 0, false, fs) err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, @@ -155,7 +157,8 @@ func TestTruncate(t *testing.T) { assert.Equal(t, int64(2), fi.Size()) } - transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs) + transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, + 100, 0, true, fs) // file.Stat will fail on a closed file err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, @@ -165,7 +168,7 @@ func TestTruncate(t *testing.T) { err = transfer.Close() assert.NoError(t, err) - transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs) _, err = transfer.Truncate("mismatch", 0) assert.EqualError(t, err, errTransferMismatch.Error()) _, err = transfer.Truncate(testFile, 0) @@ -202,7 +205,8 @@ func TestTransferErrors(t *testing.T) { assert.FailNow(t, "unable to open test file") } conn := NewBaseConnection("id", ProtocolSFTP, "", "", u) - transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) + transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, + 0, 0, 0, 0, true, fs) assert.Nil(t, transfer.cancelFn) assert.Equal(t, testFile, transfer.GetFsPath()) transfer.SetCancelFn(cancelFn) @@ -228,7 +232,7 @@ func TestTransferErrors(t *testing.T) { assert.FailNow(t, "unable to open test file") } fsPath := filepath.Join(os.TempDir(), "test_file") - transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs) transfer.BytesReceived = 9 transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, errFake.Error()) @@ -247,7 +251,7 @@ func TestTransferErrors(t *testing.T) { if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } - transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs) + transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs) transfer.BytesReceived = 9 // the file is closed from the embedding struct before to call close err = file.Close() @@ -273,7 +277,8 @@ func TestRemovePartialCryptoFile(t *testing.T) { }, } conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) - transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) + transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, + 0, 0, 0, 0, true, fs) transfer.ErrTransfer = errors.New("test error") _, err = transfer.getUploadFileSize() assert.Error(t, err) diff --git a/common/transferschecker.go b/common/transferschecker.go new file mode 100644 index 00000000..35ba128b --- /dev/null +++ b/common/transferschecker.go @@ -0,0 +1,167 @@ +package common + +import ( + "errors" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/logger" + "github.com/drakkan/sftpgo/v2/util" +) + +type overquotaTransfer struct { + ConnID string + TransferID int64 +} + +// TransfersChecker defines the interface that transfer checkers must implement. +// A transfer checker ensure that multiple concurrent transfers does not exceeded +// the remaining user quota +type TransfersChecker interface { + AddTransfer(transfer dataprovider.ActiveTransfer) + RemoveTransfer(ID int64, connectionID string) + UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) + GetOverquotaTransfers() []overquotaTransfer +} + +func getTransfersChecker() TransfersChecker { + return &transfersCheckerMem{} +} + +type transfersCheckerMem struct { + sync.RWMutex + transfers []dataprovider.ActiveTransfer +} + +func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) { + t.Lock() + defer t.Unlock() + + t.transfers = append(t.transfers, transfer) +} + +func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) { + t.Lock() + defer t.Unlock() + + for idx, transfer := range t.transfers { + if transfer.ID == ID && transfer.ConnID == connectionID { + lastIdx := len(t.transfers) - 1 + t.transfers[idx] = t.transfers[lastIdx] + t.transfers = t.transfers[:lastIdx] + return + } + } +} + +func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) { + t.Lock() + defer t.Unlock() + + for idx := range t.transfers { + if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID { + t.transfers[idx].CurrentDLSize = dlSize + t.transfers[idx].CurrentULSize = ulSize + t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + return + } + } +} + +func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) { + var result int64 + + if folderName != "" { + for _, folder := range user.VirtualFolders { + if folder.Name == folderName { + if folder.QuotaSize > 0 { + return folder.QuotaSize - folder.UsedQuotaSize, nil + } + } + } + } else { + if user.QuotaSize > 0 { + return user.QuotaSize - user.UsedQuotaSize, nil + } + } + + return result, errors.New("no quota limit defined") +} + +func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) { + t.RLock() + defer t.RUnlock() + + usersToFetch := make(map[string]bool) + aggregations := make(map[string][]dataprovider.ActiveTransfer) + for _, transfer := range t.transfers { + key := transfer.GetKey() + aggregations[key] = append(aggregations[key], transfer) + if len(aggregations[key]) > 1 { + if transfer.FolderName != "" { + usersToFetch[transfer.Username] = true + } else { + if _, ok := usersToFetch[transfer.Username]; !ok { + usersToFetch[transfer.Username] = false + } + } + } + } + + return usersToFetch, aggregations +} + +func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer { + usersToFetch, aggregations := t.aggregateTransfers() + + if len(usersToFetch) == 0 { + return nil + } + + users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) + if err != nil { + logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err) + return nil + } + + usersMap := make(map[string]dataprovider.User) + + for _, user := range users { + usersMap[user.Username] = user + } + + var overquotaTransfers []overquotaTransfer + + for _, transfers := range aggregations { + if len(transfers) > 1 { + username := transfers[0].Username + folderName := transfers[0].FolderName + // transfer type is always upload for now + remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName) + if err != nil { + continue + } + var usedDiskQuota int64 + for _, tr := range transfers { + // We optimistically assume that a cloud transfer that replaces an existing + // file will be successful + usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize + } + logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v", + username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota) + if usedDiskQuota > remaningDiskQuota { + for _, tr := range transfers { + if tr.CurrentULSize > tr.TruncatedSize { + overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ + ConnID: tr.ConnID, + TransferID: tr.ID, + }) + } + } + } + } + } + + return overquotaTransfers +} diff --git a/common/transferschecker_test.go b/common/transferschecker_test.go new file mode 100644 index 00000000..9345b1d0 --- /dev/null +++ b/common/transferschecker_test.go @@ -0,0 +1,449 @@ +package common + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/util" + "github.com/drakkan/sftpgo/v2/vfs" +) + +func TestTransfersCheckerDiskQuota(t *testing.T) { + username := "transfers_check_username" + folderName := "test_transfers_folder" + vdirPath := "/vdir" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "testpwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + QuotaSize: 120, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), folderName), + }, + VirtualPath: vdirPath, + QuotaSize: 100, + }, + }, + } + + err := dataprovider.AddUser(&user, "", "") + assert.NoError(t, err) + user, err = dataprovider.UserExists(username) + assert.NoError(t, err) + + connID1 := xid.New().String() + fsUser, err := user.GetFilesystemForPath("/file1", connID1) + assert.NoError(t, err) + conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user) + fakeConn1 := &fakeConnection{ + BaseConnection: conn1, + } + transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser) + transfer1.BytesReceived = 150 + Connections.Add(fakeConn1) + // the transferschecker will do nothing if there is only one ongoing transfer + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + + connID2 := xid.New().String() + conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user) + fakeConn2 := &fakeConnection{ + BaseConnection: conn2, + } + transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), + "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser) + transfer1.BytesReceived = 50 + transfer2.BytesReceived = 60 + Connections.Add(fakeConn2) + + connID3 := xid.New().String() + conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) + fakeConn3 := &fakeConnection{ + BaseConnection: conn3, + } + transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), + "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser) + transfer3.BytesReceived = 60 // this value will be ignored, this is a download + Connections.Add(fakeConn3) + + // the transfers are not overquota + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + + transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + transfer1.BytesReceived = 120 + // we are now overquota + // if another check is in progress nothing is done + atomic.StoreInt32(&Connections.transfersCheckStatus, 1) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + atomic.StoreInt32(&Connections.transfersCheckStatus, 0) + + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) + assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort)) + assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError())) + assert.Nil(t, transfer3.errAbort) + assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError())) + // update the user quota size + user.QuotaSize = 1000 + err = dataprovider.UpdateUser(&user, "", "") + assert.NoError(t, err) + transfer1.errAbort = nil + transfer2.errAbort = nil + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + + user.QuotaSize = 0 + err = dataprovider.UpdateUser(&user, "", "") + assert.NoError(t, err) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + // now check a public folder + transfer1.BytesReceived = 0 + transfer2.BytesReceived = 0 + connID4 := xid.New().String() + fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4) + assert.NoError(t, err) + conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) + fakeConn4 := &fakeConnection{ + BaseConnection: conn4, + } + transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"), + filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0, + 100, 0, true, fsFolder) + Connections.Add(fakeConn4) + connID5 := xid.New().String() + conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user) + fakeConn5 := &fakeConnection{ + BaseConnection: conn5, + } + transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"), + filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0, + 100, 0, true, fsFolder) + + Connections.Add(fakeConn5) + transfer4.BytesReceived = 50 + transfer5.BytesReceived = 40 + Connections.checkTransfers() + assert.Nil(t, transfer4.errAbort) + assert.Nil(t, transfer5.errAbort) + transfer5.BytesReceived = 60 + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort)) + assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort)) + + if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName { + providerConf := dataprovider.GetProviderConfig() + err = dataprovider.Close() + assert.NoError(t, err) + + transfer4.errAbort = nil + transfer5.errAbort = nil + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + assert.Nil(t, transfer4.errAbort) + assert.Nil(t, transfer5.errAbort) + + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + } + + Connections.Remove(fakeConn1.GetID()) + Connections.Remove(fakeConn2.GetID()) + Connections.Remove(fakeConn3.GetID()) + Connections.Remove(fakeConn4.GetID()) + Connections.Remove(fakeConn5.GetID()) + stats := Connections.GetStats() + assert.Len(t, stats, 0) + + err = dataprovider.DeleteUser(user.Username, "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.DeleteFolder(folderName, "", "") + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), folderName)) + assert.NoError(t, err) +} + +func TestAggregateTransfers(t *testing.T) { + checker := transfersCheckerMem{} + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "1", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + usersToFetch, aggregations := checker.aggregateTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 1) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferDownload, + ConnID: "2", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 0, + CurrentDLSize: 100, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "3", + Username: "user", + FolderName: "folder", + TruncatedSize: 0, + CurrentULSize: 10, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 3) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "4", + Username: "user1", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 4) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "5", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 1) + val, ok := usersToFetch["user"] + assert.True(t, ok) + assert.False(t, val) + assert.Len(t, aggregations, 4) + aggregate, ok := aggregations["user0"] + assert.True(t, ok) + assert.Len(t, aggregate, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "6", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.False(t, val) + assert.Len(t, aggregations, 4) + aggregate, ok = aggregations["user0"] + assert.True(t, ok) + assert.Len(t, aggregate, 3) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "7", + Username: "user", + FolderName: "folder", + TruncatedSize: 0, + CurrentULSize: 10, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.True(t, val) + assert.Len(t, aggregations, 4) + aggregate, ok = aggregations["user0"] + assert.True(t, ok) + assert.Len(t, aggregate, 3) + aggregate, ok = aggregations["userfolder0"] + assert.True(t, ok) + assert.Len(t, aggregate, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "8", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.True(t, val) + assert.Len(t, aggregations, 4) + aggregate, ok = aggregations["user0"] + assert.True(t, ok) + assert.Len(t, aggregate, 4) + aggregate, ok = aggregations["userfolder0"] + assert.True(t, ok) + assert.Len(t, aggregate, 2) +} + +func TestGetUsersForQuotaCheck(t *testing.T) { + usersToFetch := make(map[string]bool) + for i := 0; i < 50; i++ { + usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0 + } + + users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 0) + + for i := 0; i < 40; i++ { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: fmt.Sprintf("user%v", i), + Password: "pwd", + HomeDir: filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)), + Status: 1, + QuotaSize: 120, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: fmt.Sprintf("f%v", i), + MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)), + }, + VirtualPath: "/vfolder", + QuotaSize: 100, + }, + }, + } + err = dataprovider.AddUser(&user, "", "") + assert.NoError(t, err) + err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false) + assert.NoError(t, err) + } + + users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 40) + + for _, user := range users { + userIdxStr := strings.Replace(user.Username, "user", "", 1) + userIdx, err := strconv.Atoi(userIdxStr) + assert.NoError(t, err) + if userIdx%2 == 0 { + if assert.Len(t, user.VirtualFolders, 1, user.Username) { + assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize) + assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize) + } + } else { + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + assert.Len(t, user.VirtualFolders, 0, user.Username) + } + } + } + + for i := 0; i < 40; i++ { + err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "") + assert.NoError(t, err) + err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "") + assert.NoError(t, err) + } + + users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 0) +} diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index eaf3f5a2..9d944855 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -647,6 +647,53 @@ func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { return nil, nil } +func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + users := make([]User, 0, 30) + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var user User + err := json.Unmarshal(v, &user) + if err != nil { + return err + } + needFolders, ok := toFetch[user.Username] + if !ok { + continue + } + if needFolders && len(user.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + baseFolder, err := folderExistsInternal(folder.Name, foldersBucket) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) + } + user.VirtualFolders = folders + } + + user.SetEmptySecretsIfNil() + user.PrepareForRendering() + users = append(users, user) + } + return nil + }) + + return users, err +} + func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) { users := make([]User, 0, limit) var err error diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 6f7e63ea..4e2ccb61 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -381,6 +381,26 @@ func (c *Config) IsDefenderSupported() bool { } } +// ActiveTransfer defines an active protocol transfer +type ActiveTransfer struct { + ID int64 + Type int + ConnID string + Username string + FolderName string + TruncatedSize int64 + CurrentULSize int64 + CurrentDLSize int64 + CreatedAt int64 + UpdatedAt int64 +} + +// GetKey returns an aggregation key. +// The same key will be returned for similar transfers +func (t *ActiveTransfer) GetKey() string { + return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type) +} + // DefenderEntry defines a defender entry type DefenderEntry struct { ID int64 `json:"-"` @@ -476,6 +496,7 @@ type Provider interface { getUsers(limit int, offset int, order string) ([]User, error) dumpUsers() ([]User, error) getRecentlyUpdatedUsers(after int64) ([]User, error) + getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) updateLastLogin(username string) error updateAdminLastLogin(username string) error setUpdatedAt(username string) @@ -1268,6 +1289,11 @@ func GetUsers(limit, offset int, order string) ([]User, error) { return provider.getUsers(limit, offset, order) } +// GetUsersForQuotaCheck returns the users with the fields required for a quota check +func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + return provider.getUsersForQuotaCheck(toFetch) +} + // AddFolder adds a new virtual folder. func AddFolder(folder *vfs.BaseVirtualFolder) error { return provider.addFolder(folder) diff --git a/dataprovider/memory.go b/dataprovider/memory.go index 5c85af74..81fb5d1b 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -349,6 +349,7 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) { for _, username := range p.dbHandle.usernames { u := p.dbHandle.users[username] user := u.getACopy() + p.addVirtualFoldersToUser(&user) err = addCredentialsToUser(&user) if err != nil { return users, err @@ -376,6 +377,28 @@ func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { return nil, nil } +func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + users := make([]User, 0, 30) + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return users, errMemoryProviderClosed + } + for _, username := range p.dbHandle.usernames { + if val, ok := toFetch[username]; ok { + u := p.dbHandle.users[username] + user := u.getACopy() + if val { + p.addVirtualFoldersToUser(&user) + } + user.PrepareForRendering() + users = append(users, user) + } + } + + return users, nil +} + func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) { users := make([]User, 0, limit) var err error @@ -396,6 +419,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, } u := p.dbHandle.users[username] user := u.getACopy() + p.addVirtualFoldersToUser(&user) user.PrepareForRendering() users = append(users, user) if len(users) >= limit { @@ -411,6 +435,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, username := p.dbHandle.usernames[i] u := p.dbHandle.users[username] user := u.getACopy() + p.addVirtualFoldersToUser(&user) user.PrepareForRendering() users = append(users, user) if len(users) >= limit { @@ -427,7 +452,12 @@ func (p *MemoryProvider) userExists(username string) (User, error) { if p.dbHandle.isClosed { return User{}, errMemoryProviderClosed } - return p.userExistsInternal(username) + user, err := p.userExistsInternal(username) + if err != nil { + return user, err + } + p.addVirtualFoldersToUser(&user) + return user, nil } func (p *MemoryProvider) userExistsInternal(username string) (User, error) { @@ -632,6 +662,22 @@ func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolde return folders } +func (p *MemoryProvider) addVirtualFoldersToUser(user *User) { + if len(user.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + baseFolder, err := p.folderExistsInternal(folder.Name) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder.GetACopy() + folders = append(folders, *folder) + } + user.VirtualFolders = folders + } +} + func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) { folder, err := p.folderExistsInternal(folderName) if err == nil { @@ -655,7 +701,8 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold } func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64, - usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) { + usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error, +) { folder, err := p.folderExistsInternal(baseFolder.Name) if err == nil { // exists diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 1991e6c6..71432831 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -186,6 +186,10 @@ func (p *MySQLProvider) getUsers(limit int, offset int, order string) ([]User, e return sqlCommonGetUsers(limit, offset, order, p.dbHandle) } +func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) +} + func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { return sqlCommonDumpFolders(p.dbHandle) } diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 4f01b59d..7eb37197 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -198,6 +198,10 @@ func (p *PGSQLProvider) getUsers(limit int, offset int, order string) ([]User, e return sqlCommonGetUsers(limit, offset, order, p.dbHandle) } +func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) +} + func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { return sqlCommonDumpFolders(p.dbHandle) } diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 4bba10f1..7dce91f9 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -939,6 +939,90 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, return getUsersWithVirtualFolders(ctx, users, dbHandle) } +func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, 30) + + usernames := make([]string, 0, len(toFetch)) + for k := range toFetch { + usernames = append(usernames, k) + } + + maxUsers := 30 + for len(usernames) > 0 { + if maxUsers > len(usernames) { + maxUsers = len(usernames) + } + usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle) + if err != nil { + return users, err + } + users = append(users, usersRange...) + usernames = usernames[maxUsers:] + } + + var usersWithFolders []User + + validIdx := 0 + for _, user := range users { + if toFetch[user.Username] { + usersWithFolders = append(usersWithFolders, user) + } else { + users[validIdx] = user + validIdx++ + } + } + users = users[:validIdx] + if len(usersWithFolders) == 0 { + return users, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle) + if err != nil { + return users, err + } + users = append(users, usersWithFolders...) + return users, nil +} + +func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, len(usernames)) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUsersForQuotaCheckQuery(len(usernames)) + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return users, err + } + defer stmt.Close() + + queryArgs := make([]interface{}, 0, len(usernames)) + for idx := range usernames { + queryArgs = append(queryArgs, usernames[idx]) + } + + rows, err := stmt.QueryContext(ctx, queryArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var user User + err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize) + if err != nil { + return users, err + } + users = append(users, user) + } + + return users, rows.Err() +} + func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) { users := make([]User, 0, limit) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index b9b55432..264849ec 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -183,6 +183,10 @@ func (p *SQLiteProvider) getUsers(limit int, offset int, order string) ([]User, return sqlCommonGetUsers(limit, offset, order, p.dbHandle) } +func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) +} + func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { return sqlCommonDumpFolders(p.dbHandle) } diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 133c1a89..35ef7265 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -21,7 +21,7 @@ const ( func getSQLPlaceholders() []string { var placeholders []string - for i := 1; i <= 30; i++ { + for i := 1; i <= 50; i++ { if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { placeholders = append(placeholders, fmt.Sprintf("$%v", i)) } else { @@ -263,6 +263,23 @@ func getUsersQuery(order string) string { order, sqlPlaceholders[0], sqlPlaceholders[1]) } +func getUsersForQuotaCheckQuery(numArgs int) string { + var sb strings.Builder + for idx := 0; idx < numArgs; idx++ { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(sqlPlaceholders[idx]) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`, + sqlTableUsers, sb.String()) +} + func getRecentlyUpdatedUsersQuery() string { return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0]) } diff --git a/ftpd/handler.go b/ftpd/handler.go index 55f2f0c3..da58808c 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -335,8 +335,8 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6 return nil, c.GetFsError(fs, err) } - baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload, - 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, + common.TransferDownload, 0, 0, 0, 0, false, fs) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, nil, r, offset) @@ -402,7 +402,7 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, w, nil, 0) @@ -452,6 +452,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve } initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota if isResume { c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize @@ -473,13 +474,14 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve } } else { initialSize = fileSize + truncatedSize = fileSize } } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs) baseTransfer.SetFtpMode(c.getFTPMode()) t := newTransfer(baseTransfer, w, nil, 0) diff --git a/ftpd/internal_test.go b/ftpd/internal_test.go index 2dc89e4c..7febea90 100644 --- a/ftpd/internal_test.go +++ b/ftpd/internal_test.go @@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) { clientContext: mockCC, } baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) tr := newTransfer(baseTransfer, nil, nil, 0) err = tr.Close() assert.NoError(t, err) @@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, - common.TransferUpload, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs) tr = newTransfer(baseTransfer, nil, r, 10) pos, err := tr.Seek(10, 0) assert.NoError(t, err) @@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) { assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, - common.TransferUpload, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs) tr = newTransfer(baseTransfer, pipeWriter, nil, 0) err = r.Close() diff --git a/go.mod b/go.mod index 1ccb0758..fdc44374 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/Azure/azure-storage-blob-go v0.14.0 github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387 - github.com/aws/aws-sdk-go v1.42.35 - github.com/cockroachdb/cockroach-go/v2 v2.2.5 + github.com/aws/aws-sdk-go v1.42.37 + github.com/cockroachdb/cockroach-go/v2 v2.2.6 github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b github.com/fclairamb/ftpserverlib v0.17.0 github.com/fclairamb/go-log v0.2.0 @@ -35,7 +35,7 @@ require ( github.com/pires/go-proxyproto v0.6.1 github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae github.com/pquerna/otp v1.3.0 - github.com/prometheus/client_golang v1.11.0 + github.com/prometheus/client_golang v1.12.0 github.com/rs/cors v1.8.2 github.com/rs/xid v1.3.0 github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50 @@ -62,8 +62,8 @@ require ( require ( cloud.google.com/go v0.100.2 // indirect - cloud.google.com/go/compute v1.0.0 // indirect - cloud.google.com/go/iam v0.1.0 // indirect + cloud.google.com/go/compute v1.1.0 // indirect + cloud.google.com/go/iam v0.1.1 // indirect github.com/Azure/azure-pipeline-go v0.2.3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.0.1 // indirect @@ -79,7 +79,7 @@ require ( github.com/goccy/go-json v0.9.3 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect - github.com/google/go-cmp v0.5.6 // indirect + github.com/google/go-cmp v0.5.7 // indirect github.com/googleapis/gax-go/v2 v2.1.1 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -126,10 +126,10 @@ require ( golang.org/x/tools v0.1.8 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 // indirect + google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect google.golang.org/grpc v1.43.0 // indirect google.golang.org/protobuf v1.27.1 // indirect - gopkg.in/ini.v1 v1.66.2 // indirect + gopkg.in/ini.v1 v1.66.3 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 5c0c6752..be2ed194 100644 --- a/go.sum +++ b/go.sum @@ -46,14 +46,14 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= -cloud.google.com/go/compute v1.0.0 h1:SJYBzih8Jj9EUm6IDirxKG0I0AGWduhtb6BmdqWarw4= -cloud.google.com/go/compute v1.0.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= +cloud.google.com/go/compute v1.1.0 h1:pyPhehLfZ6pVzRgJmXGYvCY4K7WSWRhVw0AwhgVvS84= +cloud.google.com/go/compute v1.1.0/go.mod h1:2NIffxgWfORSI7EOYMFatGTfjMLnqrOKBEyYb6NoRgA= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo= cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY= -cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY= -cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c= +cloud.google.com/go/iam v0.1.1 h1:4CapQyNFjiksks1/x7jsvsygFPhihslYk5GptIrlX68= +cloud.google.com/go/iam v0.1.1/go.mod h1:CKqrcnI/suGpybEHxZ7BMehL0oA4LpdyJdUlTl9jVMw= cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0= cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c= cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE= @@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= -github.com/aws/aws-sdk-go v1.42.35 h1:N4N9buNs4YlosI9N0+WYrq8cIZwdgv34yRbxzZlTvFs= -github.com/aws/aws-sdk-go v1.42.35/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= +github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM= +github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY= github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY= @@ -190,8 +190,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/cockroachdb/cockroach-go/v2 v2.2.5 h1:tfPdGHO5YpmrpN2ikJZYpaSGgU8WALwwjH3s+msiTQ0= -github.com/cockroachdb/cockroach-go/v2 v2.2.5/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc= +github.com/cockroachdb/cockroach-go/v2 v2.2.6 h1:LTh++UIVvmDBihDo1oYbM8+OruXheusw+ILCONlAm/w= +github.com/cockroachdb/cockroach-go/v2 v2.2.6/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c= @@ -343,8 +343,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk= github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -649,8 +650,9 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= -github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= +github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -1075,6 +1077,7 @@ google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUb google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw= google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo= +google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM= google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ= google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= @@ -1162,8 +1165,9 @@ google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ6 google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 h1:aCsSLXylHWFno0r4S3joLpiaWayvqd2Mn4iSvx4WZZc= -google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q= +google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= @@ -1217,8 +1221,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI= gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/ini.v1 v1.66.3 h1:jRskFVxYaMGAMUbN0UZ7niA9gzL9B49DOqE78vg0k3w= +gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/httpd/file.go b/httpd/file.go index 63c98c05..c43381c5 100644 --- a/httpd/file.go +++ b/httpd/file.go @@ -1,7 +1,6 @@ package httpd import ( - "errors" "io" "sync/atomic" @@ -11,8 +10,6 @@ import ( "github.com/drakkan/sftpgo/v2/vfs" ) -var errTransferAborted = errors.New("transfer aborted") - type httpdFile struct { *common.BaseTransfer writer io.WriteCloser @@ -42,7 +39,9 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, // Read reads the contents to downloads. func (f *httpdFile) Read(p []byte) (n int, err error) { if atomic.LoadInt32(&f.AbortTransfer) == 1 { - return 0, errTransferAborted + err := f.GetAbortError() + f.TransferError(err) + return 0, err } f.Connection.UpdateLastActivity() @@ -61,7 +60,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) { // Write writes the contents to upload func (f *httpdFile) Write(p []byte) (n int, err error) { if atomic.LoadInt32(&f.AbortTransfer) == 1 { - return 0, errTransferAborted + err := f.GetAbortError() + f.TransferError(err) + return 0, err } f.Connection.UpdateLastActivity() diff --git a/httpd/handler.go b/httpd/handler.go index 173fc068..9d1bcd82 100644 --- a/httpd/handler.go +++ b/httpd/handler.go @@ -6,6 +6,7 @@ import ( "os" "path" "strings" + "sync" "sync/atomic" "time" @@ -113,7 +114,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload, - 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs) return newHTTPDFile(baseTransfer, nil, r), nil } @@ -190,6 +191,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request } initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota if !isNewFile { if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) @@ -203,6 +205,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request } } else { initialSize = fileSize + truncatedSize = fileSize } if maxWriteSize > 0 { maxWriteSize += fileSize @@ -212,7 +215,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs) return newHTTPDFile(baseTransfer, w, nil), nil } @@ -232,15 +235,17 @@ func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttl type throttledReader struct { bytesRead int64 - id uint64 + id int64 limit int64 r io.ReadCloser abortTransfer int32 start time.Time conn *Connection + mu sync.Mutex + errAbort error } -func (t *throttledReader) GetID() uint64 { +func (t *throttledReader) GetID() int64 { return t.id } @@ -252,6 +257,14 @@ func (t *throttledReader) GetSize() int64 { return atomic.LoadInt64(&t.bytesRead) } +func (t *throttledReader) GetDownloadedSize() int64 { + return 0 +} + +func (t *throttledReader) GetUploadedSize() int64 { + return atomic.LoadInt64(&t.bytesRead) +} + func (t *throttledReader) GetVirtualPath() string { return "**reading request body**" } @@ -260,10 +273,31 @@ func (t *throttledReader) GetStartTime() time.Time { return t.start } -func (t *throttledReader) SignalClose() { +func (t *throttledReader) GetAbortError() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.errAbort != nil { + return t.errAbort + } + return common.ErrTransferAborted +} + +func (t *throttledReader) SignalClose(err error) { + t.mu.Lock() + t.errAbort = err + t.mu.Unlock() atomic.StoreInt32(&(t.abortTransfer), 1) } +func (t *throttledReader) GetTruncatedSize() int64 { + return 0 +} + +func (t *throttledReader) GetMaxAllowedSize() int64 { + return 0 +} + func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) { return 0, vfs.ErrVfsUnsupported } @@ -278,7 +312,7 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti func (t *throttledReader) Read(p []byte) (n int, err error) { if atomic.LoadInt32(&t.abortTransfer) == 1 { - return 0, errTransferAborted + return 0, t.GetAbortError() } t.conn.UpdateLastActivity() diff --git a/httpd/internal_test.go b/httpd/internal_test.go index 5e003a92..f5b0493e 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -1844,12 +1844,15 @@ func TestThrottledHandler(t *testing.T) { tr := &throttledReader{ r: io.NopCloser(bytes.NewBuffer(nil)), } + assert.Equal(t, int64(0), tr.GetTruncatedSize()) err := tr.Close() assert.NoError(t, err) assert.Empty(t, tr.GetRealFsPath("real path")) assert.False(t, tr.SetTimes("p", time.Now(), time.Now())) _, err = tr.Truncate("", 0) assert.ErrorIs(t, err, vfs.ErrVfsUnsupported) + err = tr.GetAbortError() + assert.ErrorIs(t, err, common.ErrTransferAborted) } func TestHTTPDFile(t *testing.T) { @@ -1879,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) { assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload, - 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs) httpdFile := newHTTPDFile(baseTransfer, nil, nil) // the file is closed, read should fail buf := make([]byte, 100) @@ -1899,9 +1902,9 @@ func TestHTTPDFile(t *testing.T) { assert.Error(t, err) assert.Error(t, httpdFile.ErrTransfer) assert.Equal(t, err, httpdFile.ErrTransfer) - httpdFile.SignalClose() + httpdFile.SignalClose(nil) _, err = httpdFile.Write(nil) - assert.ErrorIs(t, err, errTransferAborted) + assert.ErrorIs(t, err, common.ErrQuotaExceeded) } func TestChangeUserPwd(t *testing.T) { diff --git a/sftpd/handler.go b/sftpd/handler.go index 35d90956..ab6edf7d 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -85,7 +85,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload, - 0, 0, 0, false, fs) + 0, 0, 0, 0, false, fs) t := newTransfer(baseTransfer, nil, r, nil) return t, nil @@ -364,7 +364,7 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil @@ -415,6 +415,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO } initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota if isResume { c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v", filePath, fileSize, pflags.Append) @@ -436,13 +437,14 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO } } else { initialSize = fileSize + truncatedSize = fileSize } } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 8653cedf..abf7ced5 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -162,7 +162,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) - baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, + common.TransferUpload, 10, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "upload with invalid offset must fail") @@ -193,7 +194,8 @@ func TestReadWriteErrors(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) - baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) err = file.Close() assert.NoError(t, err) @@ -207,7 +209,8 @@ func TestReadWriteErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) - baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs) transfer = newTransfer(baseTransfer, nil, r, nil) err = transfer.Close() assert.NoError(t, err) @@ -217,7 +220,8 @@ func TestReadWriteErrors(t *testing.T) { r, w, err := pipeat.Pipe() assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) - baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs) transfer = newTransfer(baseTransfer, pipeWriter, nil, nil) err = r.Close() @@ -264,7 +268,8 @@ func TestTransferCancelFn(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), "") conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) - baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) errFake := errors.New("fake error, this will trigger cancelFn") @@ -971,8 +976,8 @@ func TestSystemCommandErrors(t *testing.T) { WriteError: nil, } sshCmd.connection.channel = &mockSSHChannel - baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload, - 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", + common.TransferDownload, 0, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) destBuff := make([]byte, 65535) dst := bytes.NewBuffer(destBuff) @@ -1639,7 +1644,7 @@ func TestSCPUploadFiledata(t *testing.T) { assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(), - "/"+testfile, common.TransferDownload, 0, 0, 0, true, fs) + "/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) err = scpCommand.getUploadFileData(2, transfer) @@ -1724,7 +1729,7 @@ func TestUploadError(t *testing.T) { file, err := os.Create(fileTempName) assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(), - testfile, common.TransferUpload, 0, 0, 0, true, fs) + testfile, common.TransferUpload, 0, 0, 0, 0, true, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) errFake := errors.New("fake error") @@ -1782,7 +1787,8 @@ func TestTransferFailingReader(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) - baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), + common.TransferUpload, 0, 0, 0, 0, false, fs) errRead := errors.New("read is not allowed") tr := newTransfer(baseTransfer, nil, r, errRead) _, err = tr.ReadAt(buf, 0) diff --git a/sftpd/scp.go b/sftpd/scp.go index a1eacdf6..ee3a1947 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -238,6 +238,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, } initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota if !isNewFile { if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) @@ -251,6 +252,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, } } else { initialSize = fileSize + truncatedSize = initialSize } if maxWriteSize > 0 { maxWriteSize += fileSize @@ -260,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs) t := newTransfer(baseTransfer, w, nil, nil) return c.getUploadFileData(sizeToRead, t) @@ -529,7 +531,7 @@ func (c *scpCommand) handleDownload(filePath string) error { } baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) t := newTransfer(baseTransfer, nil, r, nil) err = c.sendDownloadFileData(fs, p, stat, t) diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index eefe8a1b..7f591c36 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -356,7 +356,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { defer stdin.Close() baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs) + common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel) @@ -369,7 +369,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, false, command.fs) + common.TransferDownload, 0, 0, 0, 0, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout) @@ -383,7 +383,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, false, command.fs) + common.TransferDownload, 0, 0, 0, 0, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr) diff --git a/tests/eventsearcher/go.mod b/tests/eventsearcher/go.mod index e31067de..f0283c04 100644 --- a/tests/eventsearcher/go.mod +++ b/tests/eventsearcher/go.mod @@ -4,23 +4,23 @@ go 1.17 require ( github.com/hashicorp/go-plugin v1.4.3 - github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a + github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea ) require ( github.com/fatih/color v1.13.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.6 // indirect - github.com/hashicorp/go-hclog v1.0.0 // indirect + github.com/hashicorp/go-hclog v1.1.0 // indirect github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mitchellh/go-testing-interface v1.14.1 // indirect github.com/oklog/run v1.1.0 // indirect - golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 // indirect - golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect + golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d // indirect + golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/text v0.3.7 // indirect - google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect + google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect google.golang.org/grpc v1.43.0 // indirect google.golang.org/protobuf v1.27.1 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect diff --git a/tests/eventsearcher/go.sum b/tests/eventsearcher/go.sum index 5a30aa71..81bcf694 100644 --- a/tests/eventsearcher/go.sum +++ b/tests/eventsearcher/go.sum @@ -57,8 +57,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= -github.com/hashicorp/go-hclog v1.0.0 h1:bkKf0BeBXcSYa7f5Fyi9gMuQ8gNsxeiNpZjR6VxNZeo= github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v1.1.0 h1:QsGcniKx5/LuX2eYoeL+Np3UKYPNaN7YKpTh29h8rbw= +github.com/hashicorp/go-hclog v1.1.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM= github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= @@ -85,8 +86,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a h1:JJc19rE0eW2knPa/KIFYvqyu25CwzKltJ5Cw1kK3o4A= -github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q= +github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k= +github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -110,8 +111,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 h1:+6WJMRLHlD7X7frgp7TUZ36RnQzSf9wVVTNakEp+nqY= golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs= +golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -132,8 +134,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= +golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -156,8 +159,9 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw= google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q= +google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= diff --git a/webdavd/handler.go b/webdavd/handler.go index c834ffd6..ee9ab1e5 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -149,8 +149,8 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File } } - baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload, - 0, 0, 0, false, fs) + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, + common.TransferDownload, 0, 0, 0, 0, false, fs) return newWebDavFile(baseTransfer, nil, r), nil } @@ -214,7 +214,7 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, fs) + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs) return newWebDavFile(baseTransfer, w, nil), nil } @@ -252,6 +252,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat return nil, c.GetFsError(fs, err) } initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { @@ -264,12 +265,13 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat } } else { initialSize = fileSize + truncatedSize = fileSize } vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, false, fs) + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs) return newWebDavFile(baseTransfer, w, nil), nil } diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index f5dd1365..424c628c 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -695,7 +695,7 @@ func TestContentType(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) ctx := context.Background() baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil) err := os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) @@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) { } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferUpload, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs) davFile := newWebDavFile(baseTransfer, nil, nil) p := make([]byte, 1) _, err := davFile.Read(p) @@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Read(p) assert.True(t, os.IsNotExist(err)) @@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.True(t, os.IsNotExist(err)) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) err = os.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) f, err := os.Open(testFilePath) @@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, mockFs) + common.TransferDownload, 0, 0, 0, 0, false, mockFs) davFile = newWebDavFile(baseTransfer, nil, nil) writeContent := []byte("content\r\n") @@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.writer = f err = davFile.Close() @@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) testFileContents := []byte("content") baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferUpload, 0, 0, 0, false, fs) + common.TransferUpload, 0, 0, 0, 0, false, fs) davFile := newWebDavFile(baseTransfer, nil, nil) _, err := davFile.Seek(0, io.SeekStart) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) @@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekCurrent) assert.True(t, os.IsNotExist(err)) @@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekStart) assert.Error(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) res, err := davFile.Seek(0, io.SeekStart) assert.NoError(t, err) @@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) { assert.Nil(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekEnd) assert.True(t, os.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.reader = f davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) @@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) { assert.Equal(t, int64(5), res) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, false, fs) + common.TransferDownload, 0, 0, 0, 0, false, fs) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)