check quota usage between ongoing transfers

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-01-20 18:19:20 +01:00
parent d73be7aee5
commit d2a4178846
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
30 changed files with 1228 additions and 158 deletions

View file

@ -56,6 +56,7 @@ const (
OperationSSHCmd = "ssh_cmd" OperationSSHCmd = "ssh_cmd"
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
idleTimeoutCheckInterval = 3 * time.Minute idleTimeoutCheckInterval = 3 * time.Minute
periodicTimeoutCheckInterval = 1 * time.Minute
) )
// Stat flags // Stat flags
@ -110,6 +111,7 @@ var (
ErrCrtRevoked = errors.New("your certificate has been revoked") ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrNoCredentials = errors.New("no credential provided") ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure") ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
errNoTransfer = errors.New("requested transfer not found") errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch") errTransferMismatch = errors.New("transfer mismatch")
) )
@ -121,8 +123,9 @@ var (
Connections ActiveConnections Connections ActiveConnections
// QuotaScans is the list of active quota scans // QuotaScans is the list of active quota scans
QuotaScans ActiveScans QuotaScans ActiveScans
idleTimeoutTicker *time.Ticker transfersChecker TransfersChecker
idleTimeoutTickerDone chan bool periodicTimeoutTicker *time.Ticker
periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare} ProtocolHTTP, ProtocolHTTPShare}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
Config = c Config = c
Config.idleLoginTimeout = 2 * time.Minute Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
if Config.IdleTimeout > 0 { startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
startIdleTimeoutTicker(idleTimeoutCheckInterval)
}
Config.defender = nil Config.defender = nil
rateLimiters = make(map[string][]*rateLimiter) rateLimiters = make(map[string][]*rateLimiter)
for _, rlCfg := range c.RateLimitersConfig { for _, rlCfg := range c.RateLimitersConfig {
@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
} }
vfs.SetTempPath(c.TempPath) vfs.SetTempPath(c.TempPath)
dataprovider.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath)
transfersChecker = getTransfersChecker()
return nil return nil
} }
@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
} }
// the ticker cannot be started/stopped from multiple goroutines // the ticker cannot be started/stopped from multiple goroutines
func startIdleTimeoutTicker(duration time.Duration) { func startPeriodicTimeoutTicker(duration time.Duration) {
stopIdleTimeoutTicker() stopPeriodicTimeoutTicker()
idleTimeoutTicker = time.NewTicker(duration) periodicTimeoutTicker = time.NewTicker(duration)
idleTimeoutTickerDone = make(chan bool) periodicTimeoutTickerDone = make(chan bool)
go func() { go func() {
counter := int64(0)
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
for { for {
select { select {
case <-idleTimeoutTickerDone: case <-periodicTimeoutTickerDone:
return return
case <-idleTimeoutTicker.C: case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles() Connections.checkIdles()
} }
go Connections.checkTransfers()
}
} }
}() }()
} }
func stopIdleTimeoutTicker() { func stopPeriodicTimeoutTicker() {
if idleTimeoutTicker != nil { if periodicTimeoutTicker != nil {
idleTimeoutTicker.Stop() periodicTimeoutTicker.Stop()
idleTimeoutTickerDone <- true periodicTimeoutTickerDone <- true
idleTimeoutTicker = nil periodicTimeoutTicker = nil
} }
} }
// ActiveTransfer defines the interface for the current active transfers // ActiveTransfer defines the interface for the current active transfers
type ActiveTransfer interface { type ActiveTransfer interface {
GetID() uint64 GetID() int64
GetType() int GetType() int
GetSize() int64 GetSize() int64
GetDownloadedSize() int64
GetUploadedSize() int64
GetVirtualPath() string GetVirtualPath() string
GetStartTime() time.Time GetStartTime() time.Time
SignalClose() SignalClose(err error)
Truncate(fsPath string, size int64) (int64, error) Truncate(fsPath string, size int64) (int64, error)
GetRealFsPath(fsPath string) string GetRealFsPath(fsPath string) string
SetTimes(fsPath string, atime time.Time, mtime time.Time) bool SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
GetTruncatedSize() int64
GetMaxAllowedSize() int64
} }
// ActiveConnection defines the interface for the current active connections // ActiveConnection defines the interface for the current active connections
@ -319,6 +332,7 @@ type ActiveConnection interface {
AddTransfer(t ActiveTransfer) AddTransfer(t ActiveTransfer)
RemoveTransfer(t ActiveTransfer) RemoveTransfer(t ActiveTransfer)
GetTransfers() []ConnectionTransfer GetTransfers() []ConnectionTransfer
SignalTransferClose(transferID int64, err error)
CloseFS() error CloseFS() error
} }
@ -335,11 +349,14 @@ type StatAttributes struct {
// ConnectionTransfer defines the trasfer details to expose // ConnectionTransfer defines the trasfer details to expose
type ConnectionTransfer struct { type ConnectionTransfer struct {
ID uint64 `json:"-"` ID int64 `json:"-"`
OperationType string `json:"operation_type"` OperationType string `json:"operation_type"`
StartTime int64 `json:"start_time"` StartTime int64 `json:"start_time"`
Size int64 `json:"size"` Size int64 `json:"size"`
VirtualPath string `json:"path"` VirtualPath string `json:"path"`
MaxAllowedSize int64 `json:"-"`
ULSize int64 `json:"-"`
DLSize int64 `json:"-"`
} }
func (t *ConnectionTransfer) getConnectionTransferAsString() string { func (t *ConnectionTransfer) getConnectionTransferAsString() string {
@ -654,6 +671,7 @@ type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting // clients contains both authenticated and estabilished connections and the ones waiting
// for authentication // for authentication
clients clientsMap clients clientsMap
transfersCheckStatus int32
sync.RWMutex sync.RWMutex
connections []ActiveConnection connections []ActiveConnection
sshConnections []*SSHConnection sshConnections []*SSHConnection
@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock() conns.RUnlock()
} }
func (conns *ActiveConnections) checkTransfers() {
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return
}
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
var wg sync.WaitGroup
logger.Debug(logSender, "", "start concurrent transfers check")
conns.RLock()
// update the current size for transfers to monitors
for _, c := range conns.connections {
for _, t := range c.GetTransfers() {
if t.MaxAllowedSize > 0 {
wg.Add(1)
go func(transfer ConnectionTransfer, connID string) {
defer wg.Done()
transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
}(t, c.GetID())
}
}
}
conns.RUnlock()
logger.Debug(logSender, "", "waiting for the update of the transfers current size")
wg.Wait()
logger.Debug(logSender, "", "getting overquota transfers")
overquotaTransfers := transfersChecker.GetOverquotaTransfers()
logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
if len(overquotaTransfers) == 0 {
return
}
conns.RLock()
defer conns.RUnlock()
for _, c := range conns.connections {
for _, overquotaTransfer := range overquotaTransfers {
if c.GetID() == overquotaTransfer.ConnID {
logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
c.GetUsername(), overquotaTransfer.TransferID)
c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
}
}
}
logger.Debug(logSender, "", "transfers check completed")
}
// AddClientConnection stores a new client connection // AddClientConnection stores a new client connection
func (conns *ActiveConnections) AddClientConnection(ipAddr string) { func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
conns.clients.add(ipAddr) conns.clients.add(ipAddr)

View file

@ -408,19 +408,19 @@ func TestIdleConnections(t *testing.T) {
assert.Len(t, Connections.sshConnections, 2) assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock() Connections.RUnlock()
startIdleTimeoutTicker(100 * time.Millisecond) startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
Connections.RLock() Connections.RLock()
defer Connections.RUnlock() defer Connections.RUnlock()
return len(Connections.sshConnections) == 1 return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond) }, 1*time.Second, 200*time.Millisecond)
stopIdleTimeoutTicker() stopPeriodicTimeoutTicker()
assert.Len(t, Connections.GetStats(), 2) assert.Len(t, Connections.GetStats(), 2)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity = c.lastActivity sshConn2.lastActivity = c.lastActivity
startIdleTimeoutTicker(100 * time.Millisecond) startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
Connections.RLock() Connections.RLock()
@ -428,7 +428,7 @@ func TestIdleConnections(t *testing.T) {
return len(Connections.sshConnections) == 0 return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond) }, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections()) assert.Equal(t, int32(0), Connections.GetClientConnections())
stopIdleTimeoutTicker() stopPeriodicTimeoutTicker()
assert.True(t, customConn1.isClosed) assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed) assert.True(t, customConn2.isClosed)
@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) {
fakeConn1 := &fakeConnection{ fakeConn1 := &fakeConnection{
BaseConnection: c1, BaseConnection: c1,
} }
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs) t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs)
t1.BytesReceived = 123 t1.BytesReceived = 123
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
t2.BytesSent = 456 t2.BytesSent = 456
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user) c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
fakeConn2 := &fakeConnection{ fakeConn2 := &fakeConnection{
@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) {
BaseConnection: c3, BaseConnection: c3,
command: "PROPFIND", command: "PROPFIND",
} }
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
Connections.Add(fakeConn1) Connections.Add(fakeConn1)
Connections.Add(fakeConn2) Connections.Add(fakeConn2)
Connections.Add(fakeConn3) Connections.Add(fakeConn3)

View file

@ -27,7 +27,7 @@ type BaseConnection struct {
lastActivity int64 lastActivity int64
// unique ID for a transfer. // unique ID for a transfer.
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment // This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
transferID uint64 transferID int64
// Unique identifier for the connection // Unique identifier for the connection
ID string ID string
// user associated with this connection if any // user associated with this connection if any
@ -66,8 +66,8 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interfac
} }
// GetTransferID returns an unique transfer ID for this connection // GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() uint64 { func (c *BaseConnection) GetTransferID() int64 {
return atomic.AddUint64(&c.transferID, 1) return atomic.AddInt64(&c.transferID, 1)
} }
// GetID returns the connection ID // GetID returns the connection ID
@ -125,6 +125,27 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
c.activeTransfers = append(c.activeTransfers, t) c.activeTransfers = append(c.activeTransfers, t)
c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers)) c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
if t.GetMaxAllowedSize() > 0 {
folderName := ""
if t.GetType() == TransferUpload {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
if err == nil {
if !vfolder.IsIncludedInUserQuota() {
folderName = vfolder.Name
}
}
}
go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
ID: t.GetID(),
Type: t.GetType(),
ConnID: c.ID,
Username: c.GetUsername(),
FolderName: folderName,
TruncatedSize: t.GetTruncatedSize(),
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
}
} }
// RemoveTransfer removes the specified transfer from the active ones // RemoveTransfer removes the specified transfer from the active ones
@ -132,6 +153,10 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if t.GetMaxAllowedSize() > 0 {
go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
}
for idx, transfer := range c.activeTransfers { for idx, transfer := range c.activeTransfers {
if transfer.GetID() == t.GetID() { if transfer.GetID() == t.GetID() {
lastIdx := len(c.activeTransfers) - 1 lastIdx := len(c.activeTransfers) - 1
@ -145,6 +170,20 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID()) c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
} }
// SignalTransferClose makes the transfer fail on the next read/write with the
// specified error
func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
c.RLock()
defer c.RUnlock()
for _, t := range c.activeTransfers {
if t.GetID() == transferID {
c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
t.SignalClose(err)
}
}
}
// GetTransfers returns the active transfers // GetTransfers returns the active transfers
func (c *BaseConnection) GetTransfers() []ConnectionTransfer { func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
c.RLock() c.RLock()
@ -165,6 +204,9 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()), StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
Size: t.GetSize(), Size: t.GetSize(),
VirtualPath: t.GetVirtualPath(), VirtualPath: t.GetVirtualPath(),
MaxAllowedSize: t.GetMaxAllowedSize(),
ULSize: t.GetUploadedSize(),
DLSize: t.GetDownloadedSize(),
}) })
} }
@ -181,7 +223,7 @@ func (c *BaseConnection) SignalTransfersAbort() error {
} }
for _, t := range c.activeTransfers { for _, t := range c.activeTransfers {
t.SignalClose() t.SignalClose(ErrTransferAborted)
} }
return nil return nil
} }
@ -1208,9 +1250,8 @@ func (c *BaseConnection) GetOpUnsupportedError() error {
} }
} }
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol func getQuotaExceededError(protocol string) error {
func (c *BaseConnection) GetQuotaExceededError() error { switch protocol {
switch c.protocol {
case ProtocolSFTP: case ProtocolSFTP:
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error()) return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error())
case ProtocolFTP: case ProtocolFTP:
@ -1220,6 +1261,11 @@ func (c *BaseConnection) GetQuotaExceededError() error {
} }
} }
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
func (c *BaseConnection) GetQuotaExceededError() error {
return getQuotaExceededError(c.protocol)
}
// IsQuotaExceededError returns true if the given error is a quota exceeded error // IsQuotaExceededError returns true if the given error is a quota exceeded error
func (c *BaseConnection) IsQuotaExceededError(err error) bool { func (c *BaseConnection) IsQuotaExceededError(err error) bool {
switch c.protocol { switch c.protocol {

View file

@ -20,7 +20,7 @@ var (
// BaseTransfer contains protocols common transfer details for an upload or a download. // BaseTransfer contains protocols common transfer details for an upload or a download.
type BaseTransfer struct { //nolint:maligned type BaseTransfer struct { //nolint:maligned
ID uint64 ID int64
BytesSent int64 BytesSent int64
BytesReceived int64 BytesReceived int64
Fs vfs.Fs Fs vfs.Fs
@ -35,18 +35,21 @@ type BaseTransfer struct { //nolint:maligned
MaxWriteSize int64 MaxWriteSize int64
MinWriteOffset int64 MinWriteOffset int64
InitialSize int64 InitialSize int64
truncatedSize int64
isNewFile bool isNewFile bool
transferType int transferType int
AbortTransfer int32 AbortTransfer int32
aTime time.Time aTime time.Time
mTime time.Time mTime time.Time
sync.Mutex sync.Mutex
errAbort error
ErrTransfer error ErrTransfer error
} }
// NewBaseTransfer returns a new BaseTransfer and adds it to the given connection // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection
func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string, func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string,
transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer { transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs,
) *BaseTransfer {
t := &BaseTransfer{ t := &BaseTransfer{
ID: conn.GetTransferID(), ID: conn.GetTransferID(),
File: file, File: file,
@ -64,6 +67,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
BytesReceived: 0, BytesReceived: 0,
MaxWriteSize: maxWriteSize, MaxWriteSize: maxWriteSize,
AbortTransfer: 0, AbortTransfer: 0,
truncatedSize: truncatedSize,
Fs: fs, Fs: fs,
} }
@ -77,7 +81,7 @@ func (t *BaseTransfer) SetFtpMode(mode string) {
} }
// GetID returns the transfer ID // GetID returns the transfer ID
func (t *BaseTransfer) GetID() uint64 { func (t *BaseTransfer) GetID() int64 {
return t.ID return t.ID
} }
@ -94,19 +98,53 @@ func (t *BaseTransfer) GetSize() int64 {
return atomic.LoadInt64(&t.BytesReceived) return atomic.LoadInt64(&t.BytesReceived)
} }
// GetDownloadedSize returns the transferred size
func (t *BaseTransfer) GetDownloadedSize() int64 {
return atomic.LoadInt64(&t.BytesSent)
}
// GetUploadedSize returns the transferred size
func (t *BaseTransfer) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.BytesReceived)
}
// GetStartTime returns the start time // GetStartTime returns the start time
func (t *BaseTransfer) GetStartTime() time.Time { func (t *BaseTransfer) GetStartTime() time.Time {
return t.start return t.start
} }
// SignalClose signals that the transfer should be closed. // GetAbortError returns the error to send to the client if the transfer was aborted
// For same protocols, for example WebDAV, we have no func (t *BaseTransfer) GetAbortError() error {
// access to the network connection, so we use this method t.Lock()
// to make the next read or write to fail defer t.Unlock()
func (t *BaseTransfer) SignalClose() {
if t.errAbort != nil {
return t.errAbort
}
return getQuotaExceededError(t.Connection.protocol)
}
// SignalClose signals that the transfer should be closed after the next read/write.
// The optional error argument allow to send a specific error, otherwise a generic
// transfer aborted error is sent
func (t *BaseTransfer) SignalClose(err error) {
t.Lock()
t.errAbort = err
t.Unlock()
atomic.StoreInt32(&(t.AbortTransfer), 1) atomic.StoreInt32(&(t.AbortTransfer), 1)
} }
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
// an existing file
func (t *BaseTransfer) GetTruncatedSize() int64 {
return t.truncatedSize
}
// GetMaxAllowedSize returns the max allowed size
func (t *BaseTransfer) GetMaxAllowedSize() int64 {
return t.MaxWriteSize
}
// GetVirtualPath returns the transfer virtual path // GetVirtualPath returns the transfer virtual path
func (t *BaseTransfer) GetVirtualPath() string { func (t *BaseTransfer) GetVirtualPath() string {
return t.requestPath return t.requestPath

View file

@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) {
wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedUploadElapsed -= wantedDownloadElapsed / 10
wantedDownloadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, "", "", u) conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, true, fs) transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = testFileSize transfer.BytesReceived = testFileSize
transfer.Connection.UpdateLastActivity() transfer.Connection.UpdateLastActivity()
startTime := transfer.Connection.GetLastActivity() startTime := transfer.Connection.GetLastActivity()
@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) {
err := transfer.Close() err := transfer.Close()
assert.NoError(t, err) assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, true, fs) transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs)
transfer.BytesSent = testFileSize transfer.BytesSent = testFileSize
transfer.Connection.UpdateLastActivity() transfer.Connection.UpdateLastActivity()
startTime = transfer.Connection.GetLastActivity() startTime = transfer.Connection.GetLastActivity()
@ -101,7 +101,8 @@ func TestRealPath(t *testing.T) {
file, err := os.Create(testFile) file, err := os.Create(testFile)
require.NoError(t, err) require.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
TransferUpload, 0, 0, 0, 0, true, fs)
rPath := transfer.GetRealFsPath(testFile) rPath := transfer.GetRealFsPath(testFile)
assert.Equal(t, testFile, rPath) assert.Equal(t, testFile, rPath)
rPath = conn.getRealFsPath(testFile) rPath = conn.getRealFsPath(testFile)
@ -138,7 +139,8 @@ func TestTruncate(t *testing.T) {
_, err = file.Write([]byte("hello")) _, err = file.Write([]byte("hello"))
assert.NoError(t, err) assert.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
100, 0, false, fs)
err = conn.SetStat("/transfer_test_file", &StatAttributes{ err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2, Size: 2,
@ -155,7 +157,8 @@ func TestTruncate(t *testing.T) {
assert.Equal(t, int64(2), fi.Size()) assert.Equal(t, int64(2), fi.Size())
} }
transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs) transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
100, 0, true, fs)
// file.Stat will fail on a closed file // file.Stat will fail on a closed file
err = conn.SetStat("/transfer_test_file", &StatAttributes{ err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2, Size: 2,
@ -165,7 +168,7 @@ func TestTruncate(t *testing.T) {
err = transfer.Close() err = transfer.Close()
assert.NoError(t, err) assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, true, fs) transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs)
_, err = transfer.Truncate("mismatch", 0) _, err = transfer.Truncate("mismatch", 0)
assert.EqualError(t, err, errTransferMismatch.Error()) assert.EqualError(t, err, errTransferMismatch.Error())
_, err = transfer.Truncate(testFile, 0) _, err = transfer.Truncate(testFile, 0)
@ -202,7 +205,8 @@ func TestTransferErrors(t *testing.T) {
assert.FailNow(t, "unable to open test file") assert.FailNow(t, "unable to open test file")
} }
conn := NewBaseConnection("id", ProtocolSFTP, "", "", u) conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs)
assert.Nil(t, transfer.cancelFn) assert.Nil(t, transfer.cancelFn)
assert.Equal(t, testFile, transfer.GetFsPath()) assert.Equal(t, testFile, transfer.GetFsPath())
transfer.SetCancelFn(cancelFn) transfer.SetCancelFn(cancelFn)
@ -228,7 +232,7 @@ func TestTransferErrors(t *testing.T) {
assert.FailNow(t, "unable to open test file") assert.FailNow(t, "unable to open test file")
} }
fsPath := filepath.Join(os.TempDir(), "test_file") fsPath := filepath.Join(os.TempDir(), "test_file")
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs) transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = 9 transfer.BytesReceived = 9
transfer.TransferError(errFake) transfer.TransferError(errFake)
assert.Error(t, transfer.ErrTransfer, errFake.Error()) assert.Error(t, transfer.ErrTransfer, errFake.Error())
@ -247,7 +251,7 @@ func TestTransferErrors(t *testing.T) {
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
assert.FailNow(t, "unable to open test file") assert.FailNow(t, "unable to open test file")
} }
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs) transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
transfer.BytesReceived = 9 transfer.BytesReceived = 9
// the file is closed from the embedding struct before to call close // the file is closed from the embedding struct before to call close
err = file.Close() err = file.Close()
@ -273,7 +277,8 @@ func TestRemovePartialCryptoFile(t *testing.T) {
}, },
} }
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs)
transfer.ErrTransfer = errors.New("test error") transfer.ErrTransfer = errors.New("test error")
_, err = transfer.getUploadFileSize() _, err = transfer.getUploadFileSize()
assert.Error(t, err) assert.Error(t, err)

167
common/transferschecker.go Normal file
View file

@ -0,0 +1,167 @@
package common
import (
"errors"
"sync"
"time"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
type overquotaTransfer struct {
ConnID string
TransferID int64
}
// TransfersChecker defines the interface that transfer checkers must implement.
// A transfer checker ensure that multiple concurrent transfers does not exceeded
// the remaining user quota
type TransfersChecker interface {
AddTransfer(transfer dataprovider.ActiveTransfer)
RemoveTransfer(ID int64, connectionID string)
UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string)
GetOverquotaTransfers() []overquotaTransfer
}
func getTransfersChecker() TransfersChecker {
return &transfersCheckerMem{}
}
type transfersCheckerMem struct {
sync.RWMutex
transfers []dataprovider.ActiveTransfer
}
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
t.Lock()
defer t.Unlock()
t.transfers = append(t.transfers, transfer)
}
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
t.Lock()
defer t.Unlock()
for idx, transfer := range t.transfers {
if transfer.ID == ID && transfer.ConnID == connectionID {
lastIdx := len(t.transfers) - 1
t.transfers[idx] = t.transfers[lastIdx]
t.transfers = t.transfers[:lastIdx]
return
}
}
}
func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) {
t.Lock()
defer t.Unlock()
for idx := range t.transfers {
if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
t.transfers[idx].CurrentDLSize = dlSize
t.transfers[idx].CurrentULSize = ulSize
t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
return
}
}
}
func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
var result int64
if folderName != "" {
for _, folder := range user.VirtualFolders {
if folder.Name == folderName {
if folder.QuotaSize > 0 {
return folder.QuotaSize - folder.UsedQuotaSize, nil
}
}
}
} else {
if user.QuotaSize > 0 {
return user.QuotaSize - user.UsedQuotaSize, nil
}
}
return result, errors.New("no quota limit defined")
}
func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
t.RLock()
defer t.RUnlock()
usersToFetch := make(map[string]bool)
aggregations := make(map[string][]dataprovider.ActiveTransfer)
for _, transfer := range t.transfers {
key := transfer.GetKey()
aggregations[key] = append(aggregations[key], transfer)
if len(aggregations[key]) > 1 {
if transfer.FolderName != "" {
usersToFetch[transfer.Username] = true
} else {
if _, ok := usersToFetch[transfer.Username]; !ok {
usersToFetch[transfer.Username] = false
}
}
}
}
return usersToFetch, aggregations
}
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
usersToFetch, aggregations := t.aggregateTransfers()
if len(usersToFetch) == 0 {
return nil
}
users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
if err != nil {
logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
return nil
}
usersMap := make(map[string]dataprovider.User)
for _, user := range users {
usersMap[user.Username] = user
}
var overquotaTransfers []overquotaTransfer
for _, transfers := range aggregations {
if len(transfers) > 1 {
username := transfers[0].Username
folderName := transfers[0].FolderName
// transfer type is always upload for now
remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
if err != nil {
continue
}
var usedDiskQuota int64
for _, tr := range transfers {
// We optimistically assume that a cloud transfer that replaces an existing
// file will be successful
usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
}
logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v",
username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
if usedDiskQuota > remaningDiskQuota {
for _, tr := range transfers {
if tr.CurrentULSize > tr.TruncatedSize {
overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
ConnID: tr.ConnID,
TransferID: tr.ID,
})
}
}
}
}
}
return overquotaTransfers
}

View file

@ -0,0 +1,449 @@
package common
import (
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
func TestTransfersCheckerDiskQuota(t *testing.T) {
username := "transfers_check_username"
folderName := "test_transfers_folder"
vdirPath := "/vdir"
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: username,
Password: "testpwd",
HomeDir: filepath.Join(os.TempDir(), username),
Status: 1,
QuotaSize: 120,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
VirtualFolders: []vfs.VirtualFolder{
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: filepath.Join(os.TempDir(), folderName),
},
VirtualPath: vdirPath,
QuotaSize: 100,
},
},
}
err := dataprovider.AddUser(&user, "", "")
assert.NoError(t, err)
user, err = dataprovider.UserExists(username)
assert.NoError(t, err)
connID1 := xid.New().String()
fsUser, err := user.GetFilesystemForPath("/file1", connID1)
assert.NoError(t, err)
conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user)
fakeConn1 := &fakeConnection{
BaseConnection: conn1,
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser)
transfer1.BytesReceived = 150
Connections.Add(fakeConn1)
// the transferschecker will do nothing if there is only one ongoing transfer
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
connID2 := xid.New().String()
conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user)
fakeConn2 := &fakeConnection{
BaseConnection: conn2,
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser)
transfer1.BytesReceived = 50
transfer2.BytesReceived = 60
Connections.Add(fakeConn2)
connID3 := xid.New().String()
conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
fakeConn3 := &fakeConnection{
BaseConnection: conn3,
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser)
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
Connections.Add(fakeConn3)
// the transfers are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 120
// we are now overquota
// if another check is in progress nothing is done
atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError()))
assert.Nil(t, transfer3.errAbort)
assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError()))
// update the user quota size
user.QuotaSize = 1000
err = dataprovider.UpdateUser(&user, "", "")
assert.NoError(t, err)
transfer1.errAbort = nil
transfer2.errAbort = nil
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
user.QuotaSize = 0
err = dataprovider.UpdateUser(&user, "", "")
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
// now check a public folder
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
connID4 := xid.New().String()
fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
assert.NoError(t, err)
conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
fakeConn4 := &fakeConnection{
BaseConnection: conn4,
}
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
100, 0, true, fsFolder)
Connections.Add(fakeConn4)
connID5 := xid.New().String()
conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
fakeConn5 := &fakeConnection{
BaseConnection: conn5,
}
transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"),
filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
100, 0, true, fsFolder)
Connections.Add(fakeConn5)
transfer4.BytesReceived = 50
transfer5.BytesReceived = 40
Connections.checkTransfers()
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
transfer5.BytesReceived = 60
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort))
if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName {
providerConf := dataprovider.GetProviderConfig()
err = dataprovider.Close()
assert.NoError(t, err)
transfer4.errAbort = nil
transfer5.errAbort = nil
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
err = dataprovider.Initialize(providerConf, configDir, true)
assert.NoError(t, err)
}
Connections.Remove(fakeConn1.GetID())
Connections.Remove(fakeConn2.GetID())
Connections.Remove(fakeConn3.GetID())
Connections.Remove(fakeConn4.GetID())
Connections.Remove(fakeConn5.GetID())
stats := Connections.GetStats()
assert.Len(t, stats, 0)
err = dataprovider.DeleteUser(user.Username, "", "")
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = dataprovider.DeleteFolder(folderName, "", "")
assert.NoError(t, err)
err = os.RemoveAll(filepath.Join(os.TempDir(), folderName))
assert.NoError(t, err)
}
func TestAggregateTransfers(t *testing.T) {
checker := transfersCheckerMem{}
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "1",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations := checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 1)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferDownload,
ConnID: "2",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 0,
CurrentDLSize: 100,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "3",
Username: "user",
FolderName: "folder",
TruncatedSize: 0,
CurrentULSize: 10,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 3)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "4",
Username: "user1",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 0)
assert.Len(t, aggregations, 4)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "5",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok := usersToFetch["user"]
assert.True(t, ok)
assert.False(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok := aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "6",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.False(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 3)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "7",
Username: "user",
FolderName: "folder",
TruncatedSize: 0,
CurrentULSize: 10,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.True(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 3)
aggregate, ok = aggregations["userfolder0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
checker.AddTransfer(dataprovider.ActiveTransfer{
ID: 1,
Type: TransferUpload,
ConnID: "8",
Username: "user",
FolderName: "",
TruncatedSize: 0,
CurrentULSize: 100,
CurrentDLSize: 0,
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
})
usersToFetch, aggregations = checker.aggregateTransfers()
assert.Len(t, usersToFetch, 1)
val, ok = usersToFetch["user"]
assert.True(t, ok)
assert.True(t, val)
assert.Len(t, aggregations, 4)
aggregate, ok = aggregations["user0"]
assert.True(t, ok)
assert.Len(t, aggregate, 4)
aggregate, ok = aggregations["userfolder0"]
assert.True(t, ok)
assert.Len(t, aggregate, 2)
}
func TestGetUsersForQuotaCheck(t *testing.T) {
usersToFetch := make(map[string]bool)
for i := 0; i < 50; i++ {
usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
}
users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 0)
for i := 0; i < 40; i++ {
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: fmt.Sprintf("user%v", i),
Password: "pwd",
HomeDir: filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)),
Status: 1,
QuotaSize: 120,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
VirtualFolders: []vfs.VirtualFolder{
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: fmt.Sprintf("f%v", i),
MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
},
VirtualPath: "/vfolder",
QuotaSize: 100,
},
},
}
err = dataprovider.AddUser(&user, "", "")
assert.NoError(t, err)
err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false)
assert.NoError(t, err)
}
users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 40)
for _, user := range users {
userIdxStr := strings.Replace(user.Username, "user", "", 1)
userIdx, err := strconv.Atoi(userIdxStr)
assert.NoError(t, err)
if userIdx%2 == 0 {
if assert.Len(t, user.VirtualFolders, 1, user.Username) {
assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize)
assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize)
}
} else {
switch dataprovider.GetProviderStatus().Driver {
case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
assert.Len(t, user.VirtualFolders, 0, user.Username)
}
}
}
for i := 0; i < 40; i++ {
err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "")
assert.NoError(t, err)
err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "")
assert.NoError(t, err)
}
users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 0)
}

View file

@ -647,6 +647,53 @@ func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
return nil, nil return nil, nil
} }
func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
users := make([]User, 0, 30)
err := p.dbHandle.View(func(tx *bolt.Tx) error {
bucket, err := getUsersBucket(tx)
if err != nil {
return err
}
foldersBucket, err := getFoldersBucket(tx)
if err != nil {
return err
}
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
var user User
err := json.Unmarshal(v, &user)
if err != nil {
return err
}
needFolders, ok := toFetch[user.Username]
if !ok {
continue
}
if needFolders && len(user.VirtualFolders) > 0 {
var folders []vfs.VirtualFolder
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
baseFolder, err := folderExistsInternal(folder.Name, foldersBucket)
if err != nil {
continue
}
folder.BaseVirtualFolder = baseFolder
folders = append(folders, *folder)
}
user.VirtualFolders = folders
}
user.SetEmptySecretsIfNil()
user.PrepareForRendering()
users = append(users, user)
}
return nil
})
return users, err
}
func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) { func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) {
users := make([]User, 0, limit) users := make([]User, 0, limit)
var err error var err error

View file

@ -381,6 +381,26 @@ func (c *Config) IsDefenderSupported() bool {
} }
} }
// ActiveTransfer defines an active protocol transfer
type ActiveTransfer struct {
ID int64
Type int
ConnID string
Username string
FolderName string
TruncatedSize int64
CurrentULSize int64
CurrentDLSize int64
CreatedAt int64
UpdatedAt int64
}
// GetKey returns an aggregation key.
// The same key will be returned for similar transfers
func (t *ActiveTransfer) GetKey() string {
return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type)
}
// DefenderEntry defines a defender entry // DefenderEntry defines a defender entry
type DefenderEntry struct { type DefenderEntry struct {
ID int64 `json:"-"` ID int64 `json:"-"`
@ -476,6 +496,7 @@ type Provider interface {
getUsers(limit int, offset int, order string) ([]User, error) getUsers(limit int, offset int, order string) ([]User, error)
dumpUsers() ([]User, error) dumpUsers() ([]User, error)
getRecentlyUpdatedUsers(after int64) ([]User, error) getRecentlyUpdatedUsers(after int64) ([]User, error)
getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error)
updateLastLogin(username string) error updateLastLogin(username string) error
updateAdminLastLogin(username string) error updateAdminLastLogin(username string) error
setUpdatedAt(username string) setUpdatedAt(username string)
@ -1268,6 +1289,11 @@ func GetUsers(limit, offset int, order string) ([]User, error) {
return provider.getUsers(limit, offset, order) return provider.getUsers(limit, offset, order)
} }
// GetUsersForQuotaCheck returns the users with the fields required for a quota check
func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return provider.getUsersForQuotaCheck(toFetch)
}
// AddFolder adds a new virtual folder. // AddFolder adds a new virtual folder.
func AddFolder(folder *vfs.BaseVirtualFolder) error { func AddFolder(folder *vfs.BaseVirtualFolder) error {
return provider.addFolder(folder) return provider.addFolder(folder)

View file

@ -349,6 +349,7 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) {
for _, username := range p.dbHandle.usernames { for _, username := range p.dbHandle.usernames {
u := p.dbHandle.users[username] u := p.dbHandle.users[username]
user := u.getACopy() user := u.getACopy()
p.addVirtualFoldersToUser(&user)
err = addCredentialsToUser(&user) err = addCredentialsToUser(&user)
if err != nil { if err != nil {
return users, err return users, err
@ -376,6 +377,28 @@ func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
return nil, nil return nil, nil
} }
func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
users := make([]User, 0, 30)
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return users, errMemoryProviderClosed
}
for _, username := range p.dbHandle.usernames {
if val, ok := toFetch[username]; ok {
u := p.dbHandle.users[username]
user := u.getACopy()
if val {
p.addVirtualFoldersToUser(&user)
}
user.PrepareForRendering()
users = append(users, user)
}
}
return users, nil
}
func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) { func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) {
users := make([]User, 0, limit) users := make([]User, 0, limit)
var err error var err error
@ -396,6 +419,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
} }
u := p.dbHandle.users[username] u := p.dbHandle.users[username]
user := u.getACopy() user := u.getACopy()
p.addVirtualFoldersToUser(&user)
user.PrepareForRendering() user.PrepareForRendering()
users = append(users, user) users = append(users, user)
if len(users) >= limit { if len(users) >= limit {
@ -411,6 +435,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
username := p.dbHandle.usernames[i] username := p.dbHandle.usernames[i]
u := p.dbHandle.users[username] u := p.dbHandle.users[username]
user := u.getACopy() user := u.getACopy()
p.addVirtualFoldersToUser(&user)
user.PrepareForRendering() user.PrepareForRendering()
users = append(users, user) users = append(users, user)
if len(users) >= limit { if len(users) >= limit {
@ -427,7 +452,12 @@ func (p *MemoryProvider) userExists(username string) (User, error) {
if p.dbHandle.isClosed { if p.dbHandle.isClosed {
return User{}, errMemoryProviderClosed return User{}, errMemoryProviderClosed
} }
return p.userExistsInternal(username) user, err := p.userExistsInternal(username)
if err != nil {
return user, err
}
p.addVirtualFoldersToUser(&user)
return user, nil
} }
func (p *MemoryProvider) userExistsInternal(username string) (User, error) { func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
@ -632,6 +662,22 @@ func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolde
return folders return folders
} }
func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
if len(user.VirtualFolders) > 0 {
var folders []vfs.VirtualFolder
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
baseFolder, err := p.folderExistsInternal(folder.Name)
if err != nil {
continue
}
folder.BaseVirtualFolder = baseFolder.GetACopy()
folders = append(folders, *folder)
}
user.VirtualFolders = folders
}
}
func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) { func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) {
folder, err := p.folderExistsInternal(folderName) folder, err := p.folderExistsInternal(folderName)
if err == nil { if err == nil {
@ -655,7 +701,8 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold
} }
func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64, func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64,
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) { usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error,
) {
folder, err := p.folderExistsInternal(baseFolder.Name) folder, err := p.folderExistsInternal(baseFolder.Name)
if err == nil { if err == nil {
// exists // exists

View file

@ -186,6 +186,10 @@ func (p *MySQLProvider) getUsers(limit int, offset int, order string) ([]User, e
return sqlCommonGetUsers(limit, offset, order, p.dbHandle) return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
} }
func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle) return sqlCommonDumpFolders(p.dbHandle)
} }

View file

@ -198,6 +198,10 @@ func (p *PGSQLProvider) getUsers(limit int, offset int, order string) ([]User, e
return sqlCommonGetUsers(limit, offset, order, p.dbHandle) return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
} }
func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle) return sqlCommonDumpFolders(p.dbHandle)
} }

View file

@ -939,6 +939,90 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
return getUsersWithVirtualFolders(ctx, users, dbHandle) return getUsersWithVirtualFolders(ctx, users, dbHandle)
} }
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, 30)
usernames := make([]string, 0, len(toFetch))
for k := range toFetch {
usernames = append(usernames, k)
}
maxUsers := 30
for len(usernames) > 0 {
if maxUsers > len(usernames) {
maxUsers = len(usernames)
}
usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
if err != nil {
return users, err
}
users = append(users, usersRange...)
usernames = usernames[maxUsers:]
}
var usersWithFolders []User
validIdx := 0
for _, user := range users {
if toFetch[user.Username] {
usersWithFolders = append(usersWithFolders, user)
} else {
users[validIdx] = user
validIdx++
}
}
users = users[:validIdx]
if len(usersWithFolders) == 0 {
return users, nil
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
if err != nil {
return users, err
}
users = append(users, usersWithFolders...)
return users, nil
}
func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, len(usernames))
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUsersForQuotaCheckQuery(len(usernames))
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return users, err
}
defer stmt.Close()
queryArgs := make([]interface{}, 0, len(usernames))
for idx := range usernames {
queryArgs = append(queryArgs, usernames[idx])
}
rows, err := stmt.QueryContext(ctx, queryArgs...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var user User
err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize)
if err != nil {
return users, err
}
users = append(users, user)
}
return users, rows.Err()
}
func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) { func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, limit) users := make([]User, 0, limit)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)

View file

@ -183,6 +183,10 @@ func (p *SQLiteProvider) getUsers(limit int, offset int, order string) ([]User,
return sqlCommonGetUsers(limit, offset, order, p.dbHandle) return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
} }
func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle) return sqlCommonDumpFolders(p.dbHandle)
} }

View file

@ -21,7 +21,7 @@ const (
func getSQLPlaceholders() []string { func getSQLPlaceholders() []string {
var placeholders []string var placeholders []string
for i := 1; i <= 30; i++ { for i := 1; i <= 50; i++ {
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
placeholders = append(placeholders, fmt.Sprintf("$%v", i)) placeholders = append(placeholders, fmt.Sprintf("$%v", i))
} else { } else {
@ -263,6 +263,23 @@ func getUsersQuery(order string) string {
order, sqlPlaceholders[0], sqlPlaceholders[1]) order, sqlPlaceholders[0], sqlPlaceholders[1])
} }
func getUsersForQuotaCheckQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`,
sqlTableUsers, sb.String())
}
func getRecentlyUpdatedUsersQuery() string { func getRecentlyUpdatedUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0]) return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
} }

View file

@ -335,8 +335,8 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6
return nil, c.GetFsError(fs, err) return nil, c.GetFsError(fs, err)
} }
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
baseTransfer.SetFtpMode(c.getFTPMode()) baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, nil, r, offset) t := newTransfer(baseTransfer, nil, r, offset)
@ -402,7 +402,7 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath,
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs) common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
baseTransfer.SetFtpMode(c.getFTPMode()) baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, w, nil, 0) t := newTransfer(baseTransfer, w, nil, 0)
@ -452,6 +452,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
} }
initialSize := int64(0) initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if isResume { if isResume {
c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize) c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize)
minWriteOffset = fileSize minWriteOffset = fileSize
@ -473,13 +474,14 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
truncatedSize = fileSize
} }
} }
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
baseTransfer.SetFtpMode(c.getFTPMode()) baseTransfer.SetFtpMode(c.getFTPMode())
t := newTransfer(baseTransfer, w, nil, 0) t := newTransfer(baseTransfer, w, nil, 0)

View file

@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) {
clientContext: mockCC, clientContext: mockCC,
} }
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile, baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
tr := newTransfer(baseTransfer, nil, nil, 0) tr := newTransfer(baseTransfer, nil, nil, 0)
err = tr.Close() err = tr.Close()
assert.NoError(t, err) assert.NoError(t, err)
@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) {
r, _, err := pipeat.Pipe() r, _, err := pipeat.Pipe()
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, 0, false, fs)
tr = newTransfer(baseTransfer, nil, r, 10) tr = newTransfer(baseTransfer, nil, r, 10)
pos, err := tr.Seek(10, 0) pos, err := tr.Seek(10, 0)
assert.NoError(t, err) assert.NoError(t, err)
@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
pipeWriter := vfs.NewPipeWriter(w) pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, 0, false, fs)
tr = newTransfer(baseTransfer, pipeWriter, nil, 0) tr = newTransfer(baseTransfer, pipeWriter, nil, 0)
err = r.Close() err = r.Close()

16
go.mod
View file

@ -7,8 +7,8 @@ require (
github.com/Azure/azure-storage-blob-go v0.14.0 github.com/Azure/azure-storage-blob-go v0.14.0
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387 github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
github.com/aws/aws-sdk-go v1.42.35 github.com/aws/aws-sdk-go v1.42.37
github.com/cockroachdb/cockroach-go/v2 v2.2.5 github.com/cockroachdb/cockroach-go/v2 v2.2.6
github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
github.com/fclairamb/ftpserverlib v0.17.0 github.com/fclairamb/ftpserverlib v0.17.0
github.com/fclairamb/go-log v0.2.0 github.com/fclairamb/go-log v0.2.0
@ -35,7 +35,7 @@ require (
github.com/pires/go-proxyproto v0.6.1 github.com/pires/go-proxyproto v0.6.1
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae
github.com/pquerna/otp v1.3.0 github.com/pquerna/otp v1.3.0
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.12.0
github.com/rs/cors v1.8.2 github.com/rs/cors v1.8.2
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50 github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50
@ -62,8 +62,8 @@ require (
require ( require (
cloud.google.com/go v0.100.2 // indirect cloud.google.com/go v0.100.2 // indirect
cloud.google.com/go/compute v1.0.0 // indirect cloud.google.com/go/compute v1.1.0 // indirect
cloud.google.com/go/iam v0.1.0 // indirect cloud.google.com/go/iam v0.1.1 // indirect
github.com/Azure/azure-pipeline-go v0.2.3 // indirect github.com/Azure/azure-pipeline-go v0.2.3 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1 // indirect github.com/boombuler/barcode v1.0.1 // indirect
@ -79,7 +79,7 @@ require (
github.com/goccy/go-json v0.9.3 // indirect github.com/goccy/go-json v0.9.3 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect github.com/google/go-cmp v0.5.7 // indirect
github.com/googleapis/gax-go/v2 v2.1.1 // indirect github.com/googleapis/gax-go/v2 v2.1.1 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
@ -126,10 +126,10 @@ require (
golang.org/x/tools v0.1.8 // indirect golang.org/x/tools v0.1.8 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 // indirect google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
google.golang.org/grpc v1.43.0 // indirect google.golang.org/grpc v1.43.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/ini.v1 v1.66.2 // indirect gopkg.in/ini.v1 v1.66.3 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )

31
go.sum
View file

@ -46,14 +46,14 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
cloud.google.com/go/compute v1.0.0 h1:SJYBzih8Jj9EUm6IDirxKG0I0AGWduhtb6BmdqWarw4= cloud.google.com/go/compute v1.1.0 h1:pyPhehLfZ6pVzRgJmXGYvCY4K7WSWRhVw0AwhgVvS84=
cloud.google.com/go/compute v1.0.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= cloud.google.com/go/compute v1.1.0/go.mod h1:2NIffxgWfORSI7EOYMFatGTfjMLnqrOKBEyYb6NoRgA=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo= cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo=
cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY= cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY=
cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY= cloud.google.com/go/iam v0.1.1 h1:4CapQyNFjiksks1/x7jsvsygFPhihslYk5GptIrlX68=
cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c= cloud.google.com/go/iam v0.1.1/go.mod h1:CKqrcnI/suGpybEHxZ7BMehL0oA4LpdyJdUlTl9jVMw=
cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0= cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0=
cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c= cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c=
cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE= cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE=
@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI
github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
github.com/aws/aws-sdk-go v1.42.35 h1:N4N9buNs4YlosI9N0+WYrq8cIZwdgv34yRbxzZlTvFs= github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM=
github.com/aws/aws-sdk-go v1.42.35/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY= github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY= github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
@ -190,8 +190,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cockroachdb/cockroach-go/v2 v2.2.5 h1:tfPdGHO5YpmrpN2ikJZYpaSGgU8WALwwjH3s+msiTQ0= github.com/cockroachdb/cockroach-go/v2 v2.2.6 h1:LTh++UIVvmDBihDo1oYbM8+OruXheusw+ILCONlAm/w=
github.com/cockroachdb/cockroach-go/v2 v2.2.5/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc= github.com/cockroachdb/cockroach-go/v2 v2.2.6/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c=
@ -343,8 +343,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk= github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk=
github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M= github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -649,8 +650,9 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg=
github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@ -1075,6 +1077,7 @@ google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUb
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw= google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw=
google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo= google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo=
google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM=
google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ= google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ=
google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg= google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@ -1162,8 +1165,9 @@ google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ6
google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 h1:aCsSLXylHWFno0r4S3joLpiaWayvqd2Mn4iSvx4WZZc= google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
@ -1217,8 +1221,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.66.3 h1:jRskFVxYaMGAMUbN0UZ7niA9gzL9B49DOqE78vg0k3w=
gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -1,7 +1,6 @@
package httpd package httpd
import ( import (
"errors"
"io" "io"
"sync/atomic" "sync/atomic"
@ -11,8 +10,6 @@ import (
"github.com/drakkan/sftpgo/v2/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
var errTransferAborted = errors.New("transfer aborted")
type httpdFile struct { type httpdFile struct {
*common.BaseTransfer *common.BaseTransfer
writer io.WriteCloser writer io.WriteCloser
@ -42,7 +39,9 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
// Read reads the contents to downloads. // Read reads the contents to downloads.
func (f *httpdFile) Read(p []byte) (n int, err error) { func (f *httpdFile) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 { if atomic.LoadInt32(&f.AbortTransfer) == 1 {
return 0, errTransferAborted err := f.GetAbortError()
f.TransferError(err)
return 0, err
} }
f.Connection.UpdateLastActivity() f.Connection.UpdateLastActivity()
@ -61,7 +60,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
// Write writes the contents to upload // Write writes the contents to upload
func (f *httpdFile) Write(p []byte) (n int, err error) { func (f *httpdFile) Write(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 { if atomic.LoadInt32(&f.AbortTransfer) == 1 {
return 0, errTransferAborted err := f.GetAbortError()
f.TransferError(err)
return 0, err
} }
f.Connection.UpdateLastActivity() f.Connection.UpdateLastActivity()

View file

@ -6,6 +6,7 @@ import (
"os" "os"
"path" "path"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -113,7 +114,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io
} }
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
0, 0, 0, false, fs) 0, 0, 0, 0, false, fs)
return newHTTPDFile(baseTransfer, nil, r), nil return newHTTPDFile(baseTransfer, nil, r), nil
} }
@ -190,6 +191,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
} }
initialSize := int64(0) initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if !isNewFile { if !isNewFile {
if vfs.IsLocalOrSFTPFs(fs) { if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
@ -203,6 +205,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
truncatedSize = fileSize
} }
if maxWriteSize > 0 { if maxWriteSize > 0 {
maxWriteSize += fileSize maxWriteSize += fileSize
@ -212,7 +215,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
return newHTTPDFile(baseTransfer, w, nil), nil return newHTTPDFile(baseTransfer, w, nil), nil
} }
@ -232,15 +235,17 @@ func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttl
type throttledReader struct { type throttledReader struct {
bytesRead int64 bytesRead int64
id uint64 id int64
limit int64 limit int64
r io.ReadCloser r io.ReadCloser
abortTransfer int32 abortTransfer int32
start time.Time start time.Time
conn *Connection conn *Connection
mu sync.Mutex
errAbort error
} }
func (t *throttledReader) GetID() uint64 { func (t *throttledReader) GetID() int64 {
return t.id return t.id
} }
@ -252,6 +257,14 @@ func (t *throttledReader) GetSize() int64 {
return atomic.LoadInt64(&t.bytesRead) return atomic.LoadInt64(&t.bytesRead)
} }
func (t *throttledReader) GetDownloadedSize() int64 {
return 0
}
func (t *throttledReader) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.bytesRead)
}
func (t *throttledReader) GetVirtualPath() string { func (t *throttledReader) GetVirtualPath() string {
return "**reading request body**" return "**reading request body**"
} }
@ -260,10 +273,31 @@ func (t *throttledReader) GetStartTime() time.Time {
return t.start return t.start
} }
func (t *throttledReader) SignalClose() { func (t *throttledReader) GetAbortError() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.errAbort != nil {
return t.errAbort
}
return common.ErrTransferAborted
}
func (t *throttledReader) SignalClose(err error) {
t.mu.Lock()
t.errAbort = err
t.mu.Unlock()
atomic.StoreInt32(&(t.abortTransfer), 1) atomic.StoreInt32(&(t.abortTransfer), 1)
} }
func (t *throttledReader) GetTruncatedSize() int64 {
return 0
}
func (t *throttledReader) GetMaxAllowedSize() int64 {
return 0
}
func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) { func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) {
return 0, vfs.ErrVfsUnsupported return 0, vfs.ErrVfsUnsupported
} }
@ -278,7 +312,7 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
func (t *throttledReader) Read(p []byte) (n int, err error) { func (t *throttledReader) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&t.abortTransfer) == 1 { if atomic.LoadInt32(&t.abortTransfer) == 1 {
return 0, errTransferAborted return 0, t.GetAbortError()
} }
t.conn.UpdateLastActivity() t.conn.UpdateLastActivity()

View file

@ -1844,12 +1844,15 @@ func TestThrottledHandler(t *testing.T) {
tr := &throttledReader{ tr := &throttledReader{
r: io.NopCloser(bytes.NewBuffer(nil)), r: io.NopCloser(bytes.NewBuffer(nil)),
} }
assert.Equal(t, int64(0), tr.GetTruncatedSize())
err := tr.Close() err := tr.Close()
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, tr.GetRealFsPath("real path")) assert.Empty(t, tr.GetRealFsPath("real path"))
assert.False(t, tr.SetTimes("p", time.Now(), time.Now())) assert.False(t, tr.SetTimes("p", time.Now(), time.Now()))
_, err = tr.Truncate("", 0) _, err = tr.Truncate("", 0)
assert.ErrorIs(t, err, vfs.ErrVfsUnsupported) assert.ErrorIs(t, err, vfs.ErrVfsUnsupported)
err = tr.GetAbortError()
assert.ErrorIs(t, err, common.ErrTransferAborted)
} }
func TestHTTPDFile(t *testing.T) { func TestHTTPDFile(t *testing.T) {
@ -1879,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload,
0, 0, 0, false, fs) 0, 0, 0, 0, false, fs)
httpdFile := newHTTPDFile(baseTransfer, nil, nil) httpdFile := newHTTPDFile(baseTransfer, nil, nil)
// the file is closed, read should fail // the file is closed, read should fail
buf := make([]byte, 100) buf := make([]byte, 100)
@ -1899,9 +1902,9 @@ func TestHTTPDFile(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Error(t, httpdFile.ErrTransfer) assert.Error(t, httpdFile.ErrTransfer)
assert.Equal(t, err, httpdFile.ErrTransfer) assert.Equal(t, err, httpdFile.ErrTransfer)
httpdFile.SignalClose() httpdFile.SignalClose(nil)
_, err = httpdFile.Write(nil) _, err = httpdFile.Write(nil)
assert.ErrorIs(t, err, errTransferAborted) assert.ErrorIs(t, err, common.ErrQuotaExceeded)
} }
func TestChangeUserPwd(t *testing.T) { func TestChangeUserPwd(t *testing.T) {

View file

@ -85,7 +85,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
} }
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload,
0, 0, 0, false, fs) 0, 0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil) t := newTransfer(baseTransfer, nil, r, nil)
return t, nil return t, nil
@ -364,7 +364,7 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs) common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
t := newTransfer(baseTransfer, w, nil, errForRead) t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil return t, nil
@ -415,6 +415,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
} }
initialSize := int64(0) initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if isResume { if isResume {
c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v", c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v",
filePath, fileSize, pflags.Append) filePath, fileSize, pflags.Append)
@ -436,13 +437,14 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
truncatedSize = fileSize
} }
} }
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
t := newTransfer(baseTransfer, w, nil, errForRead) t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil return t, nil

View file

@ -162,7 +162,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
} }
fs := vfs.NewOsFs("", os.TempDir(), "") fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile,
common.TransferUpload, 10, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
_, err = transfer.WriteAt([]byte("test"), 0) _, err = transfer.WriteAt([]byte("test"), 0)
assert.Error(t, err, "upload with invalid offset must fail") assert.Error(t, err, "upload with invalid offset must fail")
@ -193,7 +194,8 @@ func TestReadWriteErrors(t *testing.T) {
} }
fs := vfs.NewOsFs("", os.TempDir(), "") fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
err = file.Close() err = file.Close()
assert.NoError(t, err) assert.NoError(t, err)
@ -207,7 +209,8 @@ func TestReadWriteErrors(t *testing.T) {
r, _, err := pipeat.Pipe() r, _, err := pipeat.Pipe()
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, nil, r, nil) transfer = newTransfer(baseTransfer, nil, r, nil)
err = transfer.Close() err = transfer.Close()
assert.NoError(t, err) assert.NoError(t, err)
@ -217,7 +220,8 @@ func TestReadWriteErrors(t *testing.T) {
r, w, err := pipeat.Pipe() r, w, err := pipeat.Pipe()
assert.NoError(t, err) assert.NoError(t, err)
pipeWriter := vfs.NewPipeWriter(w) pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, pipeWriter, nil, nil) transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
err = r.Close() err = r.Close()
@ -264,7 +268,8 @@ func TestTransferCancelFn(t *testing.T) {
} }
fs := vfs.NewOsFs("", os.TempDir(), "") fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload,
0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error, this will trigger cancelFn") errFake := errors.New("fake error, this will trigger cancelFn")
@ -971,8 +976,8 @@ func TestSystemCommandErrors(t *testing.T) {
WriteError: nil, WriteError: nil,
} }
sshCmd.connection.channel = &mockSSHChannel sshCmd.connection.channel = &mockSSHChannel
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload, baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "",
0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
destBuff := make([]byte, 65535) destBuff := make([]byte, 65535)
dst := bytes.NewBuffer(destBuff) dst := bytes.NewBuffer(destBuff)
@ -1639,7 +1644,7 @@ func TestSCPUploadFiledata(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(), baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(),
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs) "/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
err = scpCommand.getUploadFileData(2, transfer) err = scpCommand.getUploadFileData(2, transfer)
@ -1724,7 +1729,7 @@ func TestUploadError(t *testing.T) {
file, err := os.Create(fileTempName) file, err := os.Create(fileTempName)
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(), baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(),
testfile, common.TransferUpload, 0, 0, 0, true, fs) testfile, common.TransferUpload, 0, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error") errFake := errors.New("fake error")
@ -1782,7 +1787,8 @@ func TestTransferFailingReader(t *testing.T) {
r, _, err := pipeat.Pipe() r, _, err := pipeat.Pipe()
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath),
common.TransferUpload, 0, 0, 0, 0, false, fs)
errRead := errors.New("read is not allowed") errRead := errors.New("read is not allowed")
tr := newTransfer(baseTransfer, nil, r, errRead) tr := newTransfer(baseTransfer, nil, r, errRead)
_, err = tr.ReadAt(buf, 0) _, err = tr.ReadAt(buf, 0)

View file

@ -238,6 +238,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
} }
initialSize := int64(0) initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if !isNewFile { if !isNewFile {
if vfs.IsLocalOrSFTPFs(fs) { if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
@ -251,6 +252,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
truncatedSize = initialSize
} }
if maxWriteSize > 0 { if maxWriteSize > 0 {
maxWriteSize += fileSize maxWriteSize += fileSize
@ -260,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
t := newTransfer(baseTransfer, w, nil, nil) t := newTransfer(baseTransfer, w, nil, nil)
return c.getUploadFileData(sizeToRead, t) return c.getUploadFileData(sizeToRead, t)
@ -529,7 +531,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
} }
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath, baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil) t := newTransfer(baseTransfer, nil, r, nil)
err = c.sendDownloadFileData(fs, p, stat, t) err = c.sendDownloadFileData(fs, p, stat, t)

View file

@ -356,7 +356,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() { go func() {
defer stdin.Close() defer stdin.Close()
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs) common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
@ -369,7 +369,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() { go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, command.fs) common.TransferDownload, 0, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout) w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
@ -383,7 +383,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() { go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath, baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, command.fs) common.TransferDownload, 0, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil) transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr) w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)

View file

@ -4,23 +4,23 @@ go 1.17
require ( require (
github.com/hashicorp/go-plugin v1.4.3 github.com/hashicorp/go-plugin v1.4.3
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea
) )
require ( require (
github.com/fatih/color v1.13.0 // indirect github.com/fatih/color v1.13.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.6 // indirect github.com/google/go-cmp v0.5.6 // indirect
github.com/hashicorp/go-hclog v1.0.0 // indirect github.com/hashicorp/go-hclog v1.1.0 // indirect
github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mitchellh/go-testing-interface v1.14.1 // indirect github.com/mitchellh/go-testing-interface v1.14.1 // indirect
github.com/oklog/run v1.1.0 // indirect github.com/oklog/run v1.1.0 // indirect
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 // indirect golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d // indirect
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
google.golang.org/grpc v1.43.0 // indirect google.golang.org/grpc v1.43.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

View file

@ -57,8 +57,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-hclog v1.0.0 h1:bkKf0BeBXcSYa7f5Fyi9gMuQ8gNsxeiNpZjR6VxNZeo=
github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-hclog v1.1.0 h1:QsGcniKx5/LuX2eYoeL+Np3UKYPNaN7YKpTh29h8rbw=
github.com/hashicorp/go-hclog v1.1.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM= github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
@ -85,8 +86,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a h1:JJc19rE0eW2knPa/KIFYvqyu25CwzKltJ5Cw1kK3o4A= github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k=
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q= github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@ -110,8 +111,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 h1:+6WJMRLHlD7X7frgp7TUZ36RnQzSf9wVVTNakEp+nqY=
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -132,8 +134,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -156,8 +159,9 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw=
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=

View file

@ -149,8 +149,8 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File
} }
} }
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
return newWebDavFile(baseTransfer, nil, r), nil return newWebDavFile(baseTransfer, nil, r), nil
} }
@ -214,7 +214,7 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, fs) common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
return newWebDavFile(baseTransfer, w, nil), nil return newWebDavFile(baseTransfer, w, nil), nil
} }
@ -252,6 +252,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
return nil, c.GetFsError(fs, err) return nil, c.GetFsError(fs, err)
} }
initialSize := int64(0) initialSize := int64(0)
truncatedSize := int64(0) // bytes truncated and not included in quota
if vfs.IsLocalOrSFTPFs(fs) { if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil { if err == nil {
@ -264,12 +265,13 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
truncatedSize = fileSize
} }
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, false, fs) common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs)
return newWebDavFile(baseTransfer, w, nil), nil return newWebDavFile(baseTransfer, w, nil), nil
} }

View file

@ -695,7 +695,7 @@ func TestContentType(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
ctx := context.Background() ctx := context.Background()
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil) fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil)
err := os.WriteFile(testFilePath, []byte(""), os.ModePerm) err := os.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
} }
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil) davFile := newWebDavFile(baseTransfer, nil, nil)
p := make([]byte, 1) p := make([]byte, 1)
_, err := davFile.Read(p) _, err := davFile.Read(p)
@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Read(p) _, err = davFile.Read(p)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
err = os.WriteFile(testFilePath, []byte(""), os.ModePerm) err = os.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
f, err := os.Open(testFilePath) f, err := os.Open(testFilePath)
@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r) mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, mockFs) common.TransferDownload, 0, 0, 0, 0, false, mockFs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
writeContent := []byte("content\r\n") writeContent := []byte("content\r\n")
@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.writer = f davFile.writer = f
err = davFile.Close() err = davFile.Close()
@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content") testFileContents := []byte("content")
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs) common.TransferUpload, 0, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil) davFile := newWebDavFile(baseTransfer, nil, nil)
_, err := davFile.Seek(0, io.SeekStart) _, err := davFile.Seek(0, io.SeekStart)
assert.EqualError(t, err, common.ErrOpUnsupported.Error()) assert.EqualError(t, err, common.ErrOpUnsupported.Error())
@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekCurrent) _, err = davFile.Seek(0, io.SeekCurrent)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekStart) _, err = davFile.Seek(0, io.SeekStart)
assert.Error(t, err) assert.Error(t, err)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
res, err := davFile.Seek(0, io.SeekStart) res, err := davFile.Seek(0, io.SeekStart)
assert.NoError(t, err) assert.NoError(t, err)
@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekEnd) _, err = davFile.Seek(0, io.SeekEnd)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f davFile.reader = f
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
assert.Equal(t, int64(5), res) assert.Equal(t, int64(5), res)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs) common.TransferDownload, 0, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)