check quota usage between ongoing transfers

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-01-20 18:19:20 +01:00
parent d73be7aee5
commit d2a4178846
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
30 changed files with 1228 additions and 158 deletions

View file

@ -53,9 +53,10 @@ const (
operationMkdir = "mkdir"
operationRmdir = "rmdir"
// SSH command action name
OperationSSHCmd = "ssh_cmd"
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
idleTimeoutCheckInterval = 3 * time.Minute
OperationSSHCmd = "ssh_cmd"
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
idleTimeoutCheckInterval = 3 * time.Minute
periodicTimeoutCheckInterval = 1 * time.Minute
)
// Stat flags
@ -110,6 +111,7 @@ var (
ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch")
)
@ -120,10 +122,11 @@ var (
// Connections is the list of active connections
Connections ActiveConnections
// QuotaScans is the list of active quota scans
QuotaScans ActiveScans
idleTimeoutTicker *time.Ticker
idleTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
QuotaScans ActiveScans
transfersChecker TransfersChecker
periodicTimeoutTicker *time.Ticker
periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters
@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
Config = c
Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
if Config.IdleTimeout > 0 {
startIdleTimeoutTicker(idleTimeoutCheckInterval)
}
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
Config.defender = nil
rateLimiters = make(map[string][]*rateLimiter)
for _, rlCfg := range c.RateLimitersConfig {
@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
}
vfs.SetTempPath(c.TempPath)
dataprovider.SetTempPath(c.TempPath)
transfersChecker = getTransfersChecker()
return nil
}
@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
}
// the ticker cannot be started/stopped from multiple goroutines
func startIdleTimeoutTicker(duration time.Duration) {
stopIdleTimeoutTicker()
idleTimeoutTicker = time.NewTicker(duration)
idleTimeoutTickerDone = make(chan bool)
func startPeriodicTimeoutTicker(duration time.Duration) {
stopPeriodicTimeoutTicker()
periodicTimeoutTicker = time.NewTicker(duration)
periodicTimeoutTickerDone = make(chan bool)
go func() {
counter := int64(0)
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
for {
select {
case <-idleTimeoutTickerDone:
case <-periodicTimeoutTickerDone:
return
case <-idleTimeoutTicker.C:
Connections.checkIdles()
case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles()
}
go Connections.checkTransfers()
}
}
}()
}
func stopIdleTimeoutTicker() {
if idleTimeoutTicker != nil {
idleTimeoutTicker.Stop()
idleTimeoutTickerDone <- true
idleTimeoutTicker = nil
func stopPeriodicTimeoutTicker() {
if periodicTimeoutTicker != nil {
periodicTimeoutTicker.Stop()
periodicTimeoutTickerDone <- true
periodicTimeoutTicker = nil
}
}
// ActiveTransfer defines the interface for the current active transfers
type ActiveTransfer interface {
GetID() uint64
GetID() int64
GetType() int
GetSize() int64
GetDownloadedSize() int64
GetUploadedSize() int64
GetVirtualPath() string
GetStartTime() time.Time
SignalClose()
SignalClose(err error)
Truncate(fsPath string, size int64) (int64, error)
GetRealFsPath(fsPath string) string
SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
GetTruncatedSize() int64
GetMaxAllowedSize() int64
}
// ActiveConnection defines the interface for the current active connections
@ -319,6 +332,7 @@ type ActiveConnection interface {
AddTransfer(t ActiveTransfer)
RemoveTransfer(t ActiveTransfer)
GetTransfers() []ConnectionTransfer
SignalTransferClose(transferID int64, err error)
CloseFS() error
}
@ -335,11 +349,14 @@ type StatAttributes struct {
// ConnectionTransfer defines the trasfer details to expose
type ConnectionTransfer struct {
ID uint64 `json:"-"`
OperationType string `json:"operation_type"`
StartTime int64 `json:"start_time"`
Size int64 `json:"size"`
VirtualPath string `json:"path"`
ID int64 `json:"-"`
OperationType string `json:"operation_type"`
StartTime int64 `json:"start_time"`
Size int64 `json:"size"`
VirtualPath string `json:"path"`
MaxAllowedSize int64 `json:"-"`
ULSize int64 `json:"-"`
DLSize int64 `json:"-"`
}
func (t *ConnectionTransfer) getConnectionTransferAsString() string {
@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error {
type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
clients clientsMap
transfersCheckStatus int32
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock()
}
func (conns *ActiveConnections) checkTransfers() {
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return
}
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
var wg sync.WaitGroup
logger.Debug(logSender, "", "start concurrent transfers check")
conns.RLock()
// update the current size for transfers to monitors
for _, c := range conns.connections {
for _, t := range c.GetTransfers() {
if t.MaxAllowedSize > 0 {
wg.Add(1)
go func(transfer ConnectionTransfer, connID string) {
defer wg.Done()
transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
}(t, c.GetID())
}
}
}
conns.RUnlock()
logger.Debug(logSender, "", "waiting for the update of the transfers current size")
wg.Wait()
logger.Debug(logSender, "", "getting overquota transfers")
overquotaTransfers := transfersChecker.GetOverquotaTransfers()
logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
if len(overquotaTransfers) == 0 {
return
}
conns.RLock()
defer conns.RUnlock()
for _, c := range conns.connections {
for _, overquotaTransfer := range overquotaTransfers {
if c.GetID() == overquotaTransfer.ConnID {
logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
c.GetUsername(), overquotaTransfer.TransferID)
c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
}
}
}
logger.Debug(logSender, "", "transfers check completed")
}
// AddClientConnection stores a new client connection
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
conns.clients.add(ipAddr)

View file

@ -408,19 +408,19 @@ func TestIdleConnections(t *testing.T) {
assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock()
startIdleTimeoutTicker(100 * time.Millisecond)
startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
defer Connections.RUnlock()
return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond)
stopIdleTimeoutTicker()
stopPeriodicTimeoutTicker()
assert.Len(t, Connections.GetStats(), 2)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity = c.lastActivity
startIdleTimeoutTicker(100 * time.Millisecond)
startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
@ -428,7 +428,7 @@ func TestIdleConnections(t *testing.T) {
return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections())
stopIdleTimeoutTicker()
stopPeriodicTimeoutTicker()
assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed)
@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) {
fakeConn1 := &fakeConnection{
BaseConnection: c1,
}
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs)
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs)
t1.BytesReceived = 123
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
t2.BytesSent = 456
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
fakeConn2 := &fakeConnection{
@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) {
BaseConnection: c3,
command: "PROPFIND",
}
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
Connections.Add(fakeConn1)
Connections.Add(fakeConn2)
Connections.Add(fakeConn3)

View file

@ -27,7 +27,7 @@ type BaseConnection struct {
lastActivity int64
// unique ID for a transfer.
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
transferID uint64
transferID int64
// Unique identifier for the connection
ID string
// user associated with this connection if any
@ -66,8 +66,8 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interfac
}
// GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() uint64 {
return atomic.AddUint64(&c.transferID, 1)
func (c *BaseConnection) GetTransferID() int64 {
return atomic.AddInt64(&c.transferID, 1)
}
// GetID returns the connection ID
@ -125,6 +125,27 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
c.activeTransfers = append(c.activeTransfers, t)
c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
if t.GetMaxAllowedSize() > 0 {
folderName := ""
if t.GetType() == TransferUpload {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
if err == nil {
if !vfolder.IsIncludedInUserQuota() {
folderName = vfolder.Name
}
}
}
go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
ID: t.GetID(),
Type: t.GetType(),
ConnID: c.ID,
Username: c.GetUsername(),
FolderName: folderName,
TruncatedSize: t.GetTruncatedSize(),
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
}
}
// RemoveTransfer removes the specified transfer from the active ones
@ -132,6 +153,10 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
c.Lock()
defer c.Unlock()
if t.GetMaxAllowedSize() > 0 {
go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
}
for idx, transfer := range c.activeTransfers {
if transfer.GetID() == t.GetID() {
lastIdx := len(c.activeTransfers) - 1
@ -145,6 +170,20 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
}
// SignalTransferClose makes the transfer fail on the next read/write with the
// specified error
func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
c.RLock()
defer c.RUnlock()
for _, t := range c.activeTransfers {
if t.GetID() == transferID {
c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
t.SignalClose(err)
}
}
}
// GetTransfers returns the active transfers
func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
c.RLock()
@ -160,11 +199,14 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
operationType = operationUpload
}
transfers = append(transfers, ConnectionTransfer{
ID: t.GetID(),
OperationType: operationType,
StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
Size: t.GetSize(),
VirtualPath: t.GetVirtualPath(),
ID: t.GetID(),
OperationType: operationType,
StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
Size: t.GetSize(),
VirtualPath: t.GetVirtualPath(),
MaxAllowedSize: t.GetMaxAllowedSize(),
ULSize: t.GetUploadedSize(),
DLSize: t.GetDownloadedSize(),
})
}
@ -181,7 +223,7 @@ func (c *BaseConnection) SignalTransfersAbort() error {
}
for _, t := range c.activeTransfers {
t.SignalClose()
t.SignalClose(ErrTransferAborted)
}
return nil
}
@ -1208,9 +1250,8 @@ func (c *BaseConnection) GetOpUnsupportedError() error {
}
}
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
func (c *BaseConnection) GetQuotaExceededError() error {
switch c.protocol {
func getQuotaExceededError(protocol string) error {
switch protocol {
case ProtocolSFTP:
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error())
case ProtocolFTP:
@ -1220,6 +1261,11 @@ func (c *BaseConnection) GetQuotaExceededError() error {
}
}
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
func (c *BaseConnection) GetQuotaExceededError() error {
return getQuotaExceededError(c.protocol)
}
// IsQuotaExceededError returns true if the given error is a quota exceeded error
func (c *BaseConnection) IsQuotaExceededError(err error) bool {
switch c.protocol {

View file

@ -20,7 +20,7 @@ var (
// BaseTransfer contains protocols common transfer details for an upload or a download.
type BaseTransfer struct { //nolint:maligned
ID uint64
ID int64
BytesSent int64
BytesReceived int64
Fs vfs.Fs
@ -35,18 +35,21 @@ type BaseTransfer struct { //nolint:maligned
MaxWriteSize int64
MinWriteOffset int64
InitialSize int64
truncatedSize int64
isNewFile bool
transferType int
AbortTransfer int32
aTime time.Time
mTime time.Time
sync.Mutex
errAbort error
ErrTransfer error
}
// NewBaseTransfer returns a new BaseTransfer and adds it to the given connection
func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string,
transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer {
transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs,
) *BaseTransfer {
t := &BaseTransfer{
ID: conn.GetTransferID(),
File: file,
@ -64,6 +67,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
BytesReceived: 0,
MaxWriteSize: maxWriteSize,
AbortTransfer: 0,
truncatedSize: truncatedSize,
Fs: fs,
}
@ -77,7 +81,7 @@ func (t *BaseTransfer) SetFtpMode(mode string) {
}
// GetID returns the transfer ID
func (t *BaseTransfer) GetID() uint64 {
func (t *BaseTransfer) GetID() int64 {
return t.ID
}
@ -94,19 +98,53 @@ func (t *BaseTransfer) GetSize() int64 {
return atomic.LoadInt64(&t.BytesReceived)
}
// GetDownloadedSize returns the transferred size
func (t *BaseTransfer) GetDownloadedSize() int64 {
return atomic.LoadInt64(&t.BytesSent)
}
// GetUploadedSize returns the transferred size
func (t *BaseTransfer) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.BytesReceived)
}
// GetStartTime returns the start time
func (t *BaseTransfer) GetStartTime() time.Time {
return t.start
}
// SignalClose signals that the transfer should be closed.
// For same protocols, for example WebDAV, we have no
// access to the network connection, so we use this method
// to make the next read or write to fail
func (t *BaseTransfer) SignalClose() {
// GetAbortError returns the error to send to the client if the transfer was aborted
func (t *BaseTransfer) GetAbortError() error {
t.Lock()
defer t.Unlock()
if t.errAbort != nil {
return t.errAbort
}
return getQuotaExceededError(t.Connection.protocol)
}
// SignalClose signals that the transfer should be closed after the next read/write.
// The optional error argument allow to send a specific error, otherwise a generic
// transfer aborted error is sent
func (t *BaseTransfer) SignalClose(err error) {
t.Lock()
t.errAbort = err
t.Unlock()
atomic.StoreInt32(&(t.AbortTransfer), 1)
}
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
// an existing file
func (t *BaseTransfer) GetTruncatedSize() int64 {
return t.truncatedSize
}
// GetMaxAllowedSize returns the max allowed size
func (t *BaseTransfer) GetMaxAllowedSize() int64 {
return t.MaxWriteSize
}
// GetVirtualPath returns the transfer virtual path
func (t *BaseTransfer) GetVirtualPath() string {
return t.requestPath

View file

@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) {
wantedUploadElapsed -= wantedDownloadElapsed / 10
wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, true, fs)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = testFileSize
transfer.Connection.UpdateLastActivity()
startTime := transfer.Connection.GetLastActivity()
@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) {
err := transfer.Close()
assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, true, fs)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs)
transfer.BytesSent = testFileSize
transfer.Connection.UpdateLastActivity()
startTime = transfer.Connection.GetLastActivity()
@ -101,7 +101,8 @@ func TestRealPath(t *testing.T) {
file, err := os.Create(testFile)
require.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
TransferUpload, 0, 0, 0, 0, true, fs)
rPath := transfer.GetRealFsPath(testFile)
assert.Equal(t, testFile, rPath)
rPath = conn.getRealFsPath(testFile)
@ -138,7 +139,8 @@ func TestTruncate(t *testing.T) {
_, err = file.Write([]byte("hello"))
assert.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
100, 0, false, fs)
err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2,
@ -155,7 +157,8 @@ func TestTruncate(t *testing.T) {
assert.Equal(t, int64(2), fi.Size())
}
transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs)
transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
100, 0, true, fs)
// file.Stat will fail on a closed file
err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2,
@ -165,7 +168,7 @@ func TestTruncate(t *testing.T) {
err = transfer.Close()
assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, true, fs)
transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs)
_, err = transfer.Truncate("mismatch", 0)
assert.EqualError(t, err, errTransferMismatch.Error())
_, err = transfer.Truncate(testFile, 0)
@ -202,7 +205,8 @@ func TestTransferErrors(t *testing.T) {
assert.FailNow(t, "unable to open test file")
}
conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs)
assert.Nil(t, transfer.cancelFn)
assert.Equal(t, testFile, transfer.GetFsPath())
transfer.SetCancelFn(cancelFn)
@ -228,7 +232,7 @@ func TestTransferErrors(t *testing.T) {
assert.FailNow(t, "unable to open test file")
}
fsPath := filepath.Join(os.TempDir(), "test_file")
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = 9
transfer.TransferError(errFake)
assert.Error(t, transfer.ErrTransfer, errFake.Error())
@ -247,7 +251,7 @@ func TestTransferErrors(t *testing.T) {
if !assert.NoError(t, err) {
assert.FailNow(t, "unable to open test file")
}
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = 9
// the file is closed from the embedding struct before to call close
err = file.Close()
@ -273,7 +277,8 @@ func TestRemovePartialCryptoFile(t *testing.T) {
},
}
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs)
transfer.ErrTransfer = errors.New("test error")
_, err = transfer.getUploadFileSize()
assert.Error(t, err)

167
common/transferschecker.go Normal file
View file

@ -0,0 +1,167 @@
package common
import (
"errors"
"sync"
"time"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
type overquotaTransfer struct {
ConnID string
TransferID int64
}
// TransfersChecker defines the interface that transfer checkers must implement.
// A transfer checker ensure that multiple concurrent transfers does not exceeded
// the remaining user quota
type TransfersChecker interface {
AddTransfer(transfer dataprovider.ActiveTransfer)
RemoveTransfer(ID int64, connectionID string)
UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string)
GetOverquotaTransfers() []overquotaTransfer
}
func getTransfersChecker() TransfersChecker {
return &transfersCheckerMem{}
}
type transfersCheckerMem struct {
sync.RWMutex
transfers []dataprovider.ActiveTransfer
}
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
t.Lock()
defer t.Unlock()
t.transfers = append(t.transfers, transfer)
}
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
t.Lock()
defer t.Unlock()
for idx, transfer := range t.transfers {
if transfer.ID == ID && transfer.ConnID == connectionID {
lastIdx := len(t.transfers) - 1
t.transfers[idx] = t.transfers[lastIdx]
t.transfers = t.transfers[:lastIdx]
return
}
}
}
func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) {
t.Lock()
defer t.Unlock()
for idx := range t.transfers {
if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
t.transfers[idx].CurrentDLSize = dlSize
t.transfers[idx].CurrentULSize = ulSize
t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
return
}
}
}
func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
var result int64
if folderName != "" {
for _, folder := range user.VirtualFolders {
if folder.Name == folderName {
if folder.QuotaSize > 0 {
return folder.QuotaSize - folder.UsedQuotaSize, nil
}
}
}
} else {
if user.QuotaSize > 0 {
return user.QuotaSize - user.UsedQuotaSize, nil
}
}
return result, errors.New("no quota limit defined")
}
func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
t.RLock()
defer t.RUnlock()
usersToFetch := make(map[string]bool)
aggregations := make(map[string][]dataprovider.ActiveTransfer)
for _, transfer := range t.transfers {
key := transfer.GetKey()
aggregations[key] = append(aggregations[key], transfer)
if len(aggregations[key]) > 1 {
if transfer.FolderName != "" {
usersToFetch[transfer.Username] = true
} else {
if _, ok := usersToFetch[transfer.Username]; !ok {
usersToFetch[transfer.Username] = false
}
}
}
}
return usersToFetch, aggregations
}
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
usersToFetch, aggregations := t.aggregateTransfers()
if len(usersToFetch) == 0 {
return nil
}
users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
if err != nil {
logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
return nil
}
usersMap := make(map[string]dataprovider.User)
for _, user := range users {
usersMap[user.Username] = user
}
var overquotaTransfers []overquotaTransfer
for _, transfers := range aggregations {
if len(transfers) > 1 {
username := transfers[0].Username
folderName := transfers[0].FolderName
// transfer type is always upload for now
remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
if err != nil {
continue
}
var usedDiskQuota int64
for _, tr := range transfers {
// We optimistically assume that a cloud transfer that replaces an existing
// file will be successful
usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
}
logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v",
username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
if usedDiskQuota > remaningDiskQuota {
for _, tr := range transfers {
if tr.CurrentULSize > tr.TruncatedSize {
overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
ConnID: tr.ConnID,
TransferID: tr.ID,
})
}
}
}
}
}
return overquotaTransfers
}

View file

@ -0,0 +1,449 @@
package common
import (
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
func TestTransfersCheckerDiskQuota(t *testing.T) {
username := "transfers_check_username"
folderName := "test_transfers_folder"
vdirPath := "/vdir"
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: username,
Password: "testpwd",
HomeDir: filepath.Join(os.TempDir(), username),
Status: 1,
QuotaSize: 120,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
VirtualFolders: []vfs.VirtualFolder{
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: filepath.Join(os.TempDir(), folderName),
},
VirtualPath: vdirPath,
QuotaSize: 100,
},
},
}
err := dataprovider.AddUser(&user, "", "")
assert.NoError(t, err)
user, err = dataprovider.UserExists(username)
assert.NoError(t, err)
connID1 := xid.New().String()
fsUser, err := user.GetFilesystemForPath("/file1", connID1)
assert.NoError(t, err)
conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user)
fakeConn1 := &fakeConnection{
BaseConnection: conn1,
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser)
transfer1.BytesReceived = 150
Connections.Add(fakeConn1)
// the transferschecker will do nothing if there is only one ongoing transfer
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
connID2 := xid.New().String()
conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user)
fakeConn2 := &fakeConnection{
BaseConnection: conn2,
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser)
transfer1.BytesReceived = 50
transfer2.BytesReceived = 60
Connections.Add(fakeConn2)
connID3 := xid.New().String()
conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
fakeConn3 := &fakeConnection{
BaseConnection: conn3,
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser)
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
Connections.Add(fakeConn3)
// the transfers are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 120
// we are now overquota
// if another check is in progress nothing is done
atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError()))
assert.Nil(t, transfer3.errAbort)
assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError()))
// update the user quota size
user.QuotaSize = 1000
err = dataprovider.UpdateUser(&user, "", "")
assert.NoError(t, err)
transfer1.errAbort = nil
transfer2.errAbort = nil
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
user.QuotaSize = 0
err = dataprovider.UpdateUser(&user, "", "")
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
// now check a public folder
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
connID4 := xid.New().String()
fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
assert.NoError(t, err)
conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
fakeConn4 := &fakeConnection{
BaseConnection: conn4,
}
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
100, 0, true, fsFolder)
Connections.Add(fakeConn4)
connID5 := xid.New().String()
conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
fakeConn5 := &fakeConnection{
BaseConnection: conn5,
}
transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"),
filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
100, 0, true, fsFolder)
Connections.Add(fakeConn5)
transfer4.BytesReceived = 50
transfer5.BytesReceived = 40
Connections.checkTransfers()
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
transfer5.BytesReceived = 60
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort))
if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName {
providerConf := dataprovider.GetProviderConfig()
err = dataprovider.Close()
assert.NoError(t, err)
transfer4.errAbort = nil
transfer5.errAbort = nil
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
err = dataprovider.Initialize(providerConf, configDir, true)
assert.NoError(t, err)
}
Connections.Remove(fakeConn1.GetID())
Connections.Remove(fakeConn2.GetID())
Connections.Remove(fakeConn3.GetID())
Connections.Remove(fakeConn4.GetID())
Connections.Remove(fakeConn5.GetID())
stats := Connections.GetStats()
assert.Len(t, stats, 0)
err = dataprovider.DeleteUser(user.Username, "", "")
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = dataprovider.DeleteFolder(folderName, "", "")
assert.NoError(t, err)
err = os.RemoveAll(filepath.Join(os.TempDir(), folderName))
assert.NoError(t, err)
}
func TestAggregateTransfers(t *testing.T) {
checker := transfersCheckerMem{}
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "1",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations := checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 1)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferDownload,
ConnID: "2",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 0,
CurrentDLSize: 100,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "3",
Username: "user",
FolderName: "folder",
TruncatedSize: 0,
CurrentULSize: 10,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 3)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "4",
Username: "user1",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 4)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "5",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok := usersToFetch["user"]
assert.True(t, ok)
assert.False(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok := aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "6",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.False(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 3)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "7",
Username: "user",
FolderName: "folder",
TruncatedSize: 0,
CurrentULSize: 10,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.True(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 3)
aggregate, ok = aggregations["userfolder0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "8",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.True(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 4)
aggregate, ok = aggregations["userfolder0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
}
func TestGetUsersForQuotaCheck(t *testing.T) {
usersToFetch := make(map[string]bool)
for i := 0; i < 50; i++ {
usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
}
users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 0)
for i := 0; i < 40; i++ {
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: fmt.Sprintf("user%v", i),
Password: "pwd",
HomeDir: filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)),
Status: 1,
QuotaSize: 120,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
VirtualFolders: []vfs.VirtualFolder{
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: fmt.Sprintf("f%v", i),
MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
},
VirtualPath: "/vfolder",
QuotaSize: 100,
},
},
}
err = dataprovider.AddUser(&user, "", "")
assert.NoError(t, err)
err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false)
assert.NoError(t, err)
}
users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 40)
for _, user := range users {
userIdxStr := strings.Replace(user.Username, "user", "", 1)
userIdx, err := strconv.Atoi(userIdxStr)
assert.NoError(t, err)
if userIdx%2 == 0 {
if assert.Len(t, user.VirtualFolders, 1, user.Username) {
assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize)
assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize)
}
} else {
switch dataprovider.GetProviderStatus().Driver {
case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
assert.Len(t, user.VirtualFolders, 0, user.Username)
}
}
}
for i := 0; i < 40; i++ {
err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "")
assert.NoError(t, err)
err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "")
assert.NoError(t, err)
}
users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 0)
}

View file

@ -647,6 +647,53 @@ func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
return nil, nil
}
func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
users := make([]User, 0, 30)
err := p.dbHandle.View(func(tx *bolt.Tx) error {
bucket, err := getUsersBucket(tx)
if err != nil {
return err
}
foldersBucket, err := getFoldersBucket(tx)
if err != nil {
return err
}
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
var user User
err := json.Unmarshal(v, &user)
if err != nil {
return err
}
needFolders, ok := toFetch[user.Username]
if !ok {
continue
}
if needFolders && len(user.VirtualFolders) > 0 {
var folders []vfs.VirtualFolder
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
baseFolder, err := folderExistsInternal(folder.Name, foldersBucket)
if err != nil {
continue
}
folder.BaseVirtualFolder = baseFolder
folders = append(folders, *folder)
}
user.VirtualFolders = folders
}
user.SetEmptySecretsIfNil()
user.PrepareForRendering()
users = append(users, user)
}
return nil
})
return users, err
}
func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) {
users := make([]User, 0, limit)
var err error

View file

@ -381,6 +381,26 @@ func (c *Config) IsDefenderSupported() bool {
}
}
// ActiveTransfer defines an active protocol transfer
type ActiveTransfer struct {
ID int64
Type int
ConnID string
Username string
FolderName string
TruncatedSize int64
CurrentULSize int64
CurrentDLSize int64
CreatedAt int64
UpdatedAt int64
}
// GetKey returns an aggregation key.
// The same key will be returned for similar transfers
func (t *ActiveTransfer) GetKey() string {
return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type)
}
// DefenderEntry defines a defender entry
type DefenderEntry struct {
ID int64 `json:"-"`
@ -476,6 +496,7 @@ type Provider interface {
getUsers(limit int, offset int, order string) ([]User, error)
dumpUsers() ([]User, error)
getRecentlyUpdatedUsers(after int64) ([]User, error)
getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error)
updateLastLogin(username string) error
updateAdminLastLogin(username string) error
setUpdatedAt(username string)
@ -1268,6 +1289,11 @@ func GetUsers(limit, offset int, order string) ([]User, error) {
return provider.getUsers(limit, offset, order)
}
// GetUsersForQuotaCheck returns the users with the fields required for a quota check
func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return provider.getUsersForQuotaCheck(toFetch)
}
// AddFolder adds a new virtual folder.
func AddFolder(folder *vfs.BaseVirtualFolder) error {
return provider.addFolder(folder)

View file

@ -349,6 +349,7 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) {
for _, username := range p.dbHandle.usernames {
u := p.dbHandle.users[username]
user := u.getACopy()
p.addVirtualFoldersToUser(&user)
err = addCredentialsToUser(&user)
if err != nil {
return users, err
@ -376,6 +377,28 @@ func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
return nil, nil
}
func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
users := make([]User, 0, 30)
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return users, errMemoryProviderClosed
}
for _, username := range p.dbHandle.usernames {
if val, ok := toFetch[username]; ok {
u := p.dbHandle.users[username]
user := u.getACopy()
if val {
p.addVirtualFoldersToUser(&user)
}
user.PrepareForRendering()
users = append(users, user)
}
}
return users, nil
}
func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) {
users := make([]User, 0, limit)
var err error
@ -396,6 +419,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
}
u := p.dbHandle.users[username]
user := u.getACopy()
p.addVirtualFoldersToUser(&user)
user.PrepareForRendering()
users = append(users, user)
if len(users) >= limit {
@ -411,6 +435,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
username := p.dbHandle.usernames[i]
u := p.dbHandle.users[username]
user := u.getACopy()
p.addVirtualFoldersToUser(&user)
user.PrepareForRendering()
users = append(users, user)
if len(users) >= limit {
@ -427,7 +452,12 @@ func (p *MemoryProvider) userExists(username string) (User, error) {
if p.dbHandle.isClosed {
return User{}, errMemoryProviderClosed
}
return p.userExistsInternal(username)
user, err := p.userExistsInternal(username)
if err != nil {
return user, err
}
p.addVirtualFoldersToUser(&user)
return user, nil
}
func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
@ -632,6 +662,22 @@ func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolde
return folders
}
func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
if len(user.VirtualFolders) > 0 {
var folders []vfs.VirtualFolder
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
baseFolder, err := p.folderExistsInternal(folder.Name)
if err != nil {
continue
}
folder.BaseVirtualFolder = baseFolder.GetACopy()
folders = append(folders, *folder)
}
user.VirtualFolders = folders
}
}
func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) {
folder, err := p.folderExistsInternal(folderName)
if err == nil {
@ -655,7 +701,8 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold
}
func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64,
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) {
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error,
) {
folder, err := p.folderExistsInternal(baseFolder.Name)
if err == nil {
// exists

View file

@ -186,6 +186,10 @@ func (p *MySQLProvider) getUsers(limit int, offset int, order string) ([]User, e
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
}
func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle)
}

View file

@ -198,6 +198,10 @@ func (p *PGSQLProvider) getUsers(limit int, offset int, order string) ([]User, e
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
}
func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle)
}

View file

@ -939,6 +939,90 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
return getUsersWithVirtualFolders(ctx, users, dbHandle)
}
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, 30)
usernames := make([]string, 0, len(toFetch))
for k := range toFetch {
usernames = append(usernames, k)
}
maxUsers := 30
for len(usernames) > 0 {
if maxUsers > len(usernames) {
maxUsers = len(usernames)
}
usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
if err != nil {
return users, err
}
users = append(users, usersRange...)
usernames = usernames[maxUsers:]
}
var usersWithFolders []User
validIdx := 0
for _, user := range users {
if toFetch[user.Username] {
usersWithFolders = append(usersWithFolders, user)
} else {
users[validIdx] = user
validIdx++
}
}
users = users[:validIdx]
if len(usersWithFolders) == 0 {
return users, nil
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
if err != nil {
return users, err
}
users = append(users, usersWithFolders...)
return users, nil
}
func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, len(usernames))
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUsersForQuotaCheckQuery(len(usernames))
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return users, err
}
defer stmt.Close()
queryArgs := make([]interface{}, 0, len(usernames))
for idx := range usernames {
queryArgs = append(queryArgs, usernames[idx])
}
rows, err := stmt.QueryContext(ctx, queryArgs...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var user User
err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize)
if err != nil {
return users, err
}
users = append(users, user)
}
return users, rows.Err()
}
func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, limit)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)

View file

@ -183,6 +183,10 @@ func (p *SQLiteProvider) getUsers(limit int, offset int, order string) ([]User,
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
}
func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle)
}

View file

@ -21,7 +21,7 @@ const (
func getSQLPlaceholders() []string {
var placeholders []string
for i := 1; i <= 30; i++ {
for i := 1; i <= 50; i++ {
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
placeholders = append(placeholders, fmt.Sprintf("$%v", i))
} else {
@ -263,6 +263,23 @@ func getUsersQuery(order string) string {
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUsersForQuotaCheckQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`,
sqlTableUsers, sb.String())
}
func getRecentlyUpdatedUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}

View file

@ -335,8 +335,8 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6
return nil, c.GetFsError(fs, err)
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload,
0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
common.TransferDownload, 0, 0, 0, 0, false, fs)
baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, nil, r, offset)
@ -402,7 +402,7 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath,
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, w, nil, 0)
@ -452,6 +452,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
}
initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if isResume {
c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize)
minWriteOffset = fileSize
@ -473,13 +474,14 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
}
} else {
initialSize = fileSize
truncatedSize = fileSize
}
}
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, w, nil, 0)

View file

@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) {
clientContext: mockCC,
}
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
tr := newTransfer(baseTransfer, nil, nil, 0)
err = tr.Close()
assert.NoError(t, err)
@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) {
r, _, err := pipeat.Pipe()
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
common.TransferUpload, 0, 0, 0, false, fs)
common.TransferUpload, 0, 0, 0, 0, false, fs)
tr = newTransfer(baseTransfer, nil, r, 10)
pos, err := tr.Seek(10, 0)
assert.NoError(t, err)
@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) {
assert.NoError(t, err)
pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
common.TransferUpload, 0, 0, 0, false, fs)
common.TransferUpload, 0, 0, 0, 0, false, fs)
tr = newTransfer(baseTransfer, pipeWriter, nil, 0)
err = r.Close()

16
go.mod
View file

@ -7,8 +7,8 @@ require (
github.com/Azure/azure-storage-blob-go v0.14.0
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
github.com/aws/aws-sdk-go v1.42.35
github.com/cockroachdb/cockroach-go/v2 v2.2.5
github.com/aws/aws-sdk-go v1.42.37
github.com/cockroachdb/cockroach-go/v2 v2.2.6
github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
github.com/fclairamb/ftpserverlib v0.17.0
github.com/fclairamb/go-log v0.2.0
@ -35,7 +35,7 @@ require (
github.com/pires/go-proxyproto v0.6.1
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae
github.com/pquerna/otp v1.3.0
github.com/prometheus/client_golang v1.11.0
github.com/prometheus/client_golang v1.12.0
github.com/rs/cors v1.8.2
github.com/rs/xid v1.3.0
github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50
@ -62,8 +62,8 @@ require (
require (
cloud.google.com/go v0.100.2 // indirect
cloud.google.com/go/compute v1.0.0 // indirect
cloud.google.com/go/iam v0.1.0 // indirect
cloud.google.com/go/compute v1.1.0 // indirect
cloud.google.com/go/iam v0.1.1 // indirect
github.com/Azure/azure-pipeline-go v0.2.3 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1 // indirect
@ -79,7 +79,7 @@ require (
github.com/goccy/go-json v0.9.3 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect
github.com/google/go-cmp v0.5.7 // indirect
github.com/googleapis/gax-go/v2 v2.1.1 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
@ -126,10 +126,10 @@ require (
golang.org/x/tools v0.1.8 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 // indirect
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
google.golang.org/grpc v1.43.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/ini.v1 v1.66.2 // indirect
gopkg.in/ini.v1 v1.66.3 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

31
go.sum
View file

@ -46,14 +46,14 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
cloud.google.com/go/compute v1.0.0 h1:SJYBzih8Jj9EUm6IDirxKG0I0AGWduhtb6BmdqWarw4=
cloud.google.com/go/compute v1.0.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
cloud.google.com/go/compute v1.1.0 h1:pyPhehLfZ6pVzRgJmXGYvCY4K7WSWRhVw0AwhgVvS84=
cloud.google.com/go/compute v1.1.0/go.mod h1:2NIffxgWfORSI7EOYMFatGTfjMLnqrOKBEyYb6NoRgA=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo=
cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY=
cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY=
cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c=
cloud.google.com/go/iam v0.1.1 h1:4CapQyNFjiksks1/x7jsvsygFPhihslYk5GptIrlX68=
cloud.google.com/go/iam v0.1.1/go.mod h1:CKqrcnI/suGpybEHxZ7BMehL0oA4LpdyJdUlTl9jVMw=
cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0=
cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c=
cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE=
@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI
github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
github.com/aws/aws-sdk-go v1.42.35 h1:N4N9buNs4YlosI9N0+WYrq8cIZwdgv34yRbxzZlTvFs=
github.com/aws/aws-sdk-go v1.42.35/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM=
github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
@ -190,8 +190,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cockroachdb/cockroach-go/v2 v2.2.5 h1:tfPdGHO5YpmrpN2ikJZYpaSGgU8WALwwjH3s+msiTQ0=
github.com/cockroachdb/cockroach-go/v2 v2.2.5/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
github.com/cockroachdb/cockroach-go/v2 v2.2.6 h1:LTh++UIVvmDBihDo1oYbM8+OruXheusw+ILCONlAm/w=
github.com/cockroachdb/cockroach-go/v2 v2.2.6/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c=
@ -343,8 +343,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk=
github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -649,8 +650,9 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg=
github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@ -1075,6 +1077,7 @@ google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUb
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw=
google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo=
google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM=
google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ=
google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@ -1162,8 +1165,9 @@ google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ6
google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 h1:aCsSLXylHWFno0r4S3joLpiaWayvqd2Mn4iSvx4WZZc=
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
@ -1217,8 +1221,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.66.3 h1:jRskFVxYaMGAMUbN0UZ7niA9gzL9B49DOqE78vg0k3w=
gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -1,7 +1,6 @@
package httpd
import (
"errors"
"io"
"sync/atomic"
@ -11,8 +10,6 @@ import (
"github.com/drakkan/sftpgo/v2/vfs"
)
var errTransferAborted = errors.New("transfer aborted")
type httpdFile struct {
*common.BaseTransfer
writer io.WriteCloser
@ -42,7 +39,9 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
// Read reads the contents to downloads.
func (f *httpdFile) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
return 0, errTransferAborted
err := f.GetAbortError()
f.TransferError(err)
return 0, err
}
f.Connection.UpdateLastActivity()
@ -61,7 +60,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
// Write writes the contents to upload
func (f *httpdFile) Write(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
return 0, errTransferAborted
err := f.GetAbortError()
f.TransferError(err)
return 0, err
}
f.Connection.UpdateLastActivity()

View file

@ -6,6 +6,7 @@ import (
"os"
"path"
"strings"
"sync"
"sync/atomic"
"time"
@ -113,7 +114,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
0, 0, 0, false, fs)
0, 0, 0, 0, false, fs)
return newHTTPDFile(baseTransfer, nil, r), nil
}
@ -190,6 +191,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
}
initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if !isNewFile {
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
@ -203,6 +205,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
}
} else {
initialSize = fileSize
truncatedSize = fileSize
}
if maxWriteSize > 0 {
maxWriteSize += fileSize
@ -212,7 +215,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
return newHTTPDFile(baseTransfer, w, nil), nil
}
@ -232,15 +235,17 @@ func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttl
type throttledReader struct {
bytesRead int64
id uint64
id int64
limit int64
r io.ReadCloser
abortTransfer int32
start time.Time
conn *Connection
mu sync.Mutex
errAbort error
}
func (t *throttledReader) GetID() uint64 {
func (t *throttledReader) GetID() int64 {
return t.id
}
@ -252,6 +257,14 @@ func (t *throttledReader) GetSize() int64 {
return atomic.LoadInt64(&t.bytesRead)
}
func (t *throttledReader) GetDownloadedSize() int64 {
return 0
}
func (t *throttledReader) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.bytesRead)
}
func (t *throttledReader) GetVirtualPath() string {
return "**reading request body**"
}
@ -260,10 +273,31 @@ func (t *throttledReader) GetStartTime() time.Time {
return t.start
}
func (t *throttledReader) SignalClose() {
func (t *throttledReader) GetAbortError() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.errAbort != nil {
return t.errAbort
}
return common.ErrTransferAborted
}
func (t *throttledReader) SignalClose(err error) {
t.mu.Lock()
t.errAbort = err
t.mu.Unlock()
atomic.StoreInt32(&(t.abortTransfer), 1)
}
func (t *throttledReader) GetTruncatedSize() int64 {
return 0
}
func (t *throttledReader) GetMaxAllowedSize() int64 {
return 0
}
func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) {
return 0, vfs.ErrVfsUnsupported
}
@ -278,7 +312,7 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
func (t *throttledReader) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&t.abortTransfer) == 1 {
return 0, errTransferAborted
return 0, t.GetAbortError()
}
t.conn.UpdateLastActivity()

View file

@ -1844,12 +1844,15 @@ func TestThrottledHandler(t *testing.T) {
tr := &throttledReader{
r: io.NopCloser(bytes.NewBuffer(nil)),
}
assert.Equal(t, int64(0), tr.GetTruncatedSize())
err := tr.Close()
assert.NoError(t, err)
assert.Empty(t, tr.GetRealFsPath("real path"))
assert.False(t, tr.SetTimes("p", time.Now(), time.Now()))
_, err = tr.Truncate("", 0)
assert.ErrorIs(t, err, vfs.ErrVfsUnsupported)
err = tr.GetAbortError()
assert.ErrorIs(t, err, common.ErrTransferAborted)
}
func TestHTTPDFile(t *testing.T) {
@ -1879,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) {
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload,
0, 0, 0, false, fs)
0, 0, 0, 0, false, fs)
httpdFile := newHTTPDFile(baseTransfer, nil, nil)
// the file is closed, read should fail
buf := make([]byte, 100)
@ -1899,9 +1902,9 @@ func TestHTTPDFile(t *testing.T) {
assert.Error(t, err)
assert.Error(t, httpdFile.ErrTransfer)
assert.Equal(t, err, httpdFile.ErrTransfer)
httpdFile.SignalClose()
httpdFile.SignalClose(nil)
_, err = httpdFile.Write(nil)
assert.ErrorIs(t, err, errTransferAborted)
assert.ErrorIs(t, err, common.ErrQuotaExceeded)
}
func TestChangeUserPwd(t *testing.T) {

View file

@ -85,7 +85,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload,
0, 0, 0, false, fs)
0, 0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil)
return t, nil
@ -364,7 +364,7 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil
@ -415,6 +415,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
}
initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if isResume {
c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v",
filePath, fileSize, pflags.Append)
@ -436,13 +437,14 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
}
} else {
initialSize = fileSize
truncatedSize = fileSize
}
}
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil

View file

@ -162,7 +162,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
}
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile,
common.TransferUpload, 10, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
_, err = transfer.WriteAt([]byte("test"), 0)
assert.Error(t, err, "upload with invalid offset must fail")
@ -193,7 +194,8 @@ func TestReadWriteErrors(t *testing.T) {
}
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
err = file.Close()
assert.NoError(t, err)
@ -207,7 +209,8 @@ func TestReadWriteErrors(t *testing.T) {
r, _, err := pipeat.Pipe()
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, nil, r, nil)
err = transfer.Close()
assert.NoError(t, err)
@ -217,7 +220,8 @@ func TestReadWriteErrors(t *testing.T) {
r, w, err := pipeat.Pipe()
assert.NoError(t, err)
pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
err = r.Close()
@ -264,7 +268,8 @@ func TestTransferCancelFn(t *testing.T) {
}
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error, this will trigger cancelFn")
@ -971,8 +976,8 @@ func TestSystemCommandErrors(t *testing.T) {
WriteError: nil,
}
sshCmd.connection.channel = &mockSSHChannel
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload,
0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "",
common.TransferDownload, 0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
destBuff := make([]byte, 65535)
dst := bytes.NewBuffer(destBuff)
@ -1639,7 +1644,7 @@ func TestSCPUploadFiledata(t *testing.T) {
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(),
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
"/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
err = scpCommand.getUploadFileData(2, transfer)
@ -1724,7 +1729,7 @@ func TestUploadError(t *testing.T) {
file, err := os.Create(fileTempName)
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(),
testfile, common.TransferUpload, 0, 0, 0, true, fs)
testfile, common.TransferUpload, 0, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error")
@ -1782,7 +1787,8 @@ func TestTransferFailingReader(t *testing.T) {
r, _, err := pipeat.Pipe()
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath),
common.TransferUpload, 0, 0, 0, 0, false, fs)
errRead := errors.New("read is not allowed")
tr := newTransfer(baseTransfer, nil, r, errRead)
_, err = tr.ReadAt(buf, 0)

View file

@ -238,6 +238,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
}
initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if !isNewFile {
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
@ -251,6 +252,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
}
} else {
initialSize = fileSize
truncatedSize = initialSize
}
if maxWriteSize > 0 {
maxWriteSize += fileSize
@ -260,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
t := newTransfer(baseTransfer, w, nil, nil)
return c.getUploadFileData(sizeToRead, t)
@ -529,7 +531,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
}
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil)
err = c.sendDownloadFileData(fs, p, stat, t)

View file

@ -356,7 +356,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
defer stdin.Close()
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs)
common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
@ -369,7 +369,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, command.fs)
common.TransferDownload, 0, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
@ -383,7 +383,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, command.fs)
common.TransferDownload, 0, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)

View file

@ -4,23 +4,23 @@ go 1.17
require (
github.com/hashicorp/go-plugin v1.4.3
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea
)
require (
github.com/fatih/color v1.13.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect
github.com/hashicorp/go-hclog v1.0.0 // indirect
github.com/hashicorp/go-hclog v1.1.0 // indirect
github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
github.com/oklog/run v1.1.0 // indirect
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 // indirect
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d // indirect
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
google.golang.org/grpc v1.43.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

View file

@ -57,8 +57,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-hclog v1.0.0 h1:bkKf0BeBXcSYa7f5Fyi9gMuQ8gNsxeiNpZjR6VxNZeo=
github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-hclog v1.1.0 h1:QsGcniKx5/LuX2eYoeL+Np3UKYPNaN7YKpTh29h8rbw=
github.com/hashicorp/go-hclog v1.1.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
@ -85,8 +86,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a h1:JJc19rE0eW2knPa/KIFYvqyu25CwzKltJ5Cw1kK3o4A=
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k=
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@ -110,8 +111,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 h1:+6WJMRLHlD7X7frgp7TUZ36RnQzSf9wVVTNakEp+nqY=
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -132,8 +134,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -156,8 +159,9 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=

View file

@ -149,8 +149,8 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File
}
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload,
0, 0, 0, false, fs)
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
common.TransferDownload, 0, 0, 0, 0, false, fs)
return newWebDavFile(baseTransfer, nil, r), nil
}
@ -214,7 +214,7 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
return newWebDavFile(baseTransfer, w, nil), nil
}
@ -252,6 +252,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
return nil, c.GetFsError(fs, err)
}
initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
@ -264,12 +265,13 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
}
} else {
initialSize = fileSize
truncatedSize = fileSize
}
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, false, fs)
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs)
return newWebDavFile(baseTransfer, w, nil), nil
}

View file

@ -695,7 +695,7 @@ func TestContentType(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile)
ctx := context.Background()
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil)
err := os.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err)
@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
}
testFilePath := filepath.Join(user.HomeDir, testFile)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs)
common.TransferUpload, 0, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil)
p := make([]byte, 1)
_, err := davFile.Read(p)
@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Read(p)
assert.True(t, os.IsNotExist(err))
@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.True(t, os.IsNotExist(err))
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
err = os.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err)
f, err := os.Open(testFilePath)
@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err)
mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, mockFs)
common.TransferDownload, 0, 0, 0, 0, false, mockFs)
davFile = newWebDavFile(baseTransfer, nil, nil)
writeContent := []byte("content\r\n")
@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.writer = f
err = davFile.Close()
@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content")
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs)
common.TransferUpload, 0, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil)
_, err := davFile.Seek(0, io.SeekStart)
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekCurrent)
assert.True(t, os.IsNotExist(err))
@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err)
}
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekStart)
assert.Error(t, err)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
res, err := davFile.Seek(0, io.SeekStart)
assert.NoError(t, err)
@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
assert.Nil(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekEnd)
assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
assert.Equal(t, int64(5), res)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)