mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-21 23:20:24 +00:00
check quota usage between ongoing transfers
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
d73be7aee5
commit
d2a4178846
30 changed files with 1228 additions and 158 deletions
131
common/common.go
131
common/common.go
|
@ -53,9 +53,10 @@ const (
|
|||
operationMkdir = "mkdir"
|
||||
operationRmdir = "rmdir"
|
||||
// SSH command action name
|
||||
OperationSSHCmd = "ssh_cmd"
|
||||
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
|
||||
idleTimeoutCheckInterval = 3 * time.Minute
|
||||
OperationSSHCmd = "ssh_cmd"
|
||||
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
|
||||
idleTimeoutCheckInterval = 3 * time.Minute
|
||||
periodicTimeoutCheckInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
// Stat flags
|
||||
|
@ -110,6 +111,7 @@ var (
|
|||
ErrCrtRevoked = errors.New("your certificate has been revoked")
|
||||
ErrNoCredentials = errors.New("no credential provided")
|
||||
ErrInternalFailure = errors.New("internal failure")
|
||||
ErrTransferAborted = errors.New("transfer aborted")
|
||||
errNoTransfer = errors.New("requested transfer not found")
|
||||
errTransferMismatch = errors.New("transfer mismatch")
|
||||
)
|
||||
|
@ -120,10 +122,11 @@ var (
|
|||
// Connections is the list of active connections
|
||||
Connections ActiveConnections
|
||||
// QuotaScans is the list of active quota scans
|
||||
QuotaScans ActiveScans
|
||||
idleTimeoutTicker *time.Ticker
|
||||
idleTimeoutTickerDone chan bool
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
QuotaScans ActiveScans
|
||||
transfersChecker TransfersChecker
|
||||
periodicTimeoutTicker *time.Ticker
|
||||
periodicTimeoutTickerDone chan bool
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
ProtocolHTTP, ProtocolHTTPShare}
|
||||
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
|
||||
// the map key is the protocol, for each protocol we can have multiple rate limiters
|
||||
|
@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
|
|||
Config = c
|
||||
Config.idleLoginTimeout = 2 * time.Minute
|
||||
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
|
||||
if Config.IdleTimeout > 0 {
|
||||
startIdleTimeoutTicker(idleTimeoutCheckInterval)
|
||||
}
|
||||
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
|
||||
Config.defender = nil
|
||||
rateLimiters = make(map[string][]*rateLimiter)
|
||||
for _, rlCfg := range c.RateLimitersConfig {
|
||||
|
@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
|
|||
}
|
||||
vfs.SetTempPath(c.TempPath)
|
||||
dataprovider.SetTempPath(c.TempPath)
|
||||
transfersChecker = getTransfersChecker()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
|
|||
}
|
||||
|
||||
// the ticker cannot be started/stopped from multiple goroutines
|
||||
func startIdleTimeoutTicker(duration time.Duration) {
|
||||
stopIdleTimeoutTicker()
|
||||
idleTimeoutTicker = time.NewTicker(duration)
|
||||
idleTimeoutTickerDone = make(chan bool)
|
||||
func startPeriodicTimeoutTicker(duration time.Duration) {
|
||||
stopPeriodicTimeoutTicker()
|
||||
periodicTimeoutTicker = time.NewTicker(duration)
|
||||
periodicTimeoutTickerDone = make(chan bool)
|
||||
go func() {
|
||||
counter := int64(0)
|
||||
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
|
||||
for {
|
||||
select {
|
||||
case <-idleTimeoutTickerDone:
|
||||
case <-periodicTimeoutTickerDone:
|
||||
return
|
||||
case <-idleTimeoutTicker.C:
|
||||
Connections.checkIdles()
|
||||
case <-periodicTimeoutTicker.C:
|
||||
counter++
|
||||
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
|
||||
counter = 0
|
||||
Connections.checkIdles()
|
||||
}
|
||||
go Connections.checkTransfers()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopIdleTimeoutTicker() {
|
||||
if idleTimeoutTicker != nil {
|
||||
idleTimeoutTicker.Stop()
|
||||
idleTimeoutTickerDone <- true
|
||||
idleTimeoutTicker = nil
|
||||
func stopPeriodicTimeoutTicker() {
|
||||
if periodicTimeoutTicker != nil {
|
||||
periodicTimeoutTicker.Stop()
|
||||
periodicTimeoutTickerDone <- true
|
||||
periodicTimeoutTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveTransfer defines the interface for the current active transfers
|
||||
type ActiveTransfer interface {
|
||||
GetID() uint64
|
||||
GetID() int64
|
||||
GetType() int
|
||||
GetSize() int64
|
||||
GetDownloadedSize() int64
|
||||
GetUploadedSize() int64
|
||||
GetVirtualPath() string
|
||||
GetStartTime() time.Time
|
||||
SignalClose()
|
||||
SignalClose(err error)
|
||||
Truncate(fsPath string, size int64) (int64, error)
|
||||
GetRealFsPath(fsPath string) string
|
||||
SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
|
||||
GetTruncatedSize() int64
|
||||
GetMaxAllowedSize() int64
|
||||
}
|
||||
|
||||
// ActiveConnection defines the interface for the current active connections
|
||||
|
@ -319,6 +332,7 @@ type ActiveConnection interface {
|
|||
AddTransfer(t ActiveTransfer)
|
||||
RemoveTransfer(t ActiveTransfer)
|
||||
GetTransfers() []ConnectionTransfer
|
||||
SignalTransferClose(transferID int64, err error)
|
||||
CloseFS() error
|
||||
}
|
||||
|
||||
|
@ -335,11 +349,14 @@ type StatAttributes struct {
|
|||
|
||||
// ConnectionTransfer defines the trasfer details to expose
|
||||
type ConnectionTransfer struct {
|
||||
ID uint64 `json:"-"`
|
||||
OperationType string `json:"operation_type"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
Size int64 `json:"size"`
|
||||
VirtualPath string `json:"path"`
|
||||
ID int64 `json:"-"`
|
||||
OperationType string `json:"operation_type"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
Size int64 `json:"size"`
|
||||
VirtualPath string `json:"path"`
|
||||
MaxAllowedSize int64 `json:"-"`
|
||||
ULSize int64 `json:"-"`
|
||||
DLSize int64 `json:"-"`
|
||||
}
|
||||
|
||||
func (t *ConnectionTransfer) getConnectionTransferAsString() string {
|
||||
|
@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error {
|
|||
type ActiveConnections struct {
|
||||
// clients contains both authenticated and estabilished connections and the ones waiting
|
||||
// for authentication
|
||||
clients clientsMap
|
||||
clients clientsMap
|
||||
transfersCheckStatus int32
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
sshConnections []*SSHConnection
|
||||
|
@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
|
|||
conns.RUnlock()
|
||||
}
|
||||
|
||||
func (conns *ActiveConnections) checkTransfers() {
|
||||
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
|
||||
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
|
||||
return
|
||||
}
|
||||
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
|
||||
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
logger.Debug(logSender, "", "start concurrent transfers check")
|
||||
conns.RLock()
|
||||
|
||||
// update the current size for transfers to monitors
|
||||
for _, c := range conns.connections {
|
||||
for _, t := range c.GetTransfers() {
|
||||
if t.MaxAllowedSize > 0 {
|
||||
wg.Add(1)
|
||||
|
||||
go func(transfer ConnectionTransfer, connID string) {
|
||||
defer wg.Done()
|
||||
transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
|
||||
}(t, c.GetID())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conns.RUnlock()
|
||||
logger.Debug(logSender, "", "waiting for the update of the transfers current size")
|
||||
wg.Wait()
|
||||
|
||||
logger.Debug(logSender, "", "getting overquota transfers")
|
||||
overquotaTransfers := transfersChecker.GetOverquotaTransfers()
|
||||
logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
|
||||
if len(overquotaTransfers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
||||
for _, c := range conns.connections {
|
||||
for _, overquotaTransfer := range overquotaTransfers {
|
||||
if c.GetID() == overquotaTransfer.ConnID {
|
||||
logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
|
||||
c.GetUsername(), overquotaTransfer.TransferID)
|
||||
c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Debug(logSender, "", "transfers check completed")
|
||||
}
|
||||
|
||||
// AddClientConnection stores a new client connection
|
||||
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
|
||||
conns.clients.add(ipAddr)
|
||||
|
|
|
@ -408,19 +408,19 @@ func TestIdleConnections(t *testing.T) {
|
|||
assert.Len(t, Connections.sshConnections, 2)
|
||||
Connections.RUnlock()
|
||||
|
||||
startIdleTimeoutTicker(100 * time.Millisecond)
|
||||
startPeriodicTimeoutTicker(100 * time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
|
||||
assert.Eventually(t, func() bool {
|
||||
Connections.RLock()
|
||||
defer Connections.RUnlock()
|
||||
return len(Connections.sshConnections) == 1
|
||||
}, 1*time.Second, 200*time.Millisecond)
|
||||
stopIdleTimeoutTicker()
|
||||
stopPeriodicTimeoutTicker()
|
||||
assert.Len(t, Connections.GetStats(), 2)
|
||||
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||
sshConn2.lastActivity = c.lastActivity
|
||||
startIdleTimeoutTicker(100 * time.Millisecond)
|
||||
startPeriodicTimeoutTicker(100 * time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
|
||||
assert.Eventually(t, func() bool {
|
||||
Connections.RLock()
|
||||
|
@ -428,7 +428,7 @@ func TestIdleConnections(t *testing.T) {
|
|||
return len(Connections.sshConnections) == 0
|
||||
}, 1*time.Second, 200*time.Millisecond)
|
||||
assert.Equal(t, int32(0), Connections.GetClientConnections())
|
||||
stopIdleTimeoutTicker()
|
||||
stopPeriodicTimeoutTicker()
|
||||
assert.True(t, customConn1.isClosed)
|
||||
assert.True(t, customConn2.isClosed)
|
||||
|
||||
|
@ -505,9 +505,9 @@ func TestConnectionStatus(t *testing.T) {
|
|||
fakeConn1 := &fakeConnection{
|
||||
BaseConnection: c1,
|
||||
}
|
||||
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs)
|
||||
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
t1.BytesReceived = 123
|
||||
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
|
||||
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
|
||||
t2.BytesSent = 456
|
||||
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
|
||||
fakeConn2 := &fakeConnection{
|
||||
|
@ -519,7 +519,7 @@ func TestConnectionStatus(t *testing.T) {
|
|||
BaseConnection: c3,
|
||||
command: "PROPFIND",
|
||||
}
|
||||
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
|
||||
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs)
|
||||
Connections.Add(fakeConn1)
|
||||
Connections.Add(fakeConn2)
|
||||
Connections.Add(fakeConn3)
|
||||
|
|
|
@ -27,7 +27,7 @@ type BaseConnection struct {
|
|||
lastActivity int64
|
||||
// unique ID for a transfer.
|
||||
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
|
||||
transferID uint64
|
||||
transferID int64
|
||||
// Unique identifier for the connection
|
||||
ID string
|
||||
// user associated with this connection if any
|
||||
|
@ -66,8 +66,8 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interfac
|
|||
}
|
||||
|
||||
// GetTransferID returns an unique transfer ID for this connection
|
||||
func (c *BaseConnection) GetTransferID() uint64 {
|
||||
return atomic.AddUint64(&c.transferID, 1)
|
||||
func (c *BaseConnection) GetTransferID() int64 {
|
||||
return atomic.AddInt64(&c.transferID, 1)
|
||||
}
|
||||
|
||||
// GetID returns the connection ID
|
||||
|
@ -125,6 +125,27 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
|
|||
|
||||
c.activeTransfers = append(c.activeTransfers, t)
|
||||
c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
|
||||
if t.GetMaxAllowedSize() > 0 {
|
||||
folderName := ""
|
||||
if t.GetType() == TransferUpload {
|
||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
|
||||
if err == nil {
|
||||
if !vfolder.IsIncludedInUserQuota() {
|
||||
folderName = vfolder.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
|
||||
ID: t.GetID(),
|
||||
Type: t.GetType(),
|
||||
ConnID: c.ID,
|
||||
Username: c.GetUsername(),
|
||||
FolderName: folderName,
|
||||
TruncatedSize: t.GetTruncatedSize(),
|
||||
CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveTransfer removes the specified transfer from the active ones
|
||||
|
@ -132,6 +153,10 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
|
|||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if t.GetMaxAllowedSize() > 0 {
|
||||
go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
|
||||
}
|
||||
|
||||
for idx, transfer := range c.activeTransfers {
|
||||
if transfer.GetID() == t.GetID() {
|
||||
lastIdx := len(c.activeTransfers) - 1
|
||||
|
@ -145,6 +170,20 @@ func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
|
|||
c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
|
||||
}
|
||||
|
||||
// SignalTransferClose makes the transfer fail on the next read/write with the
|
||||
// specified error
|
||||
func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
for _, t := range c.activeTransfers {
|
||||
if t.GetID() == transferID {
|
||||
c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
|
||||
t.SignalClose(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTransfers returns the active transfers
|
||||
func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
|
||||
c.RLock()
|
||||
|
@ -160,11 +199,14 @@ func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
|
|||
operationType = operationUpload
|
||||
}
|
||||
transfers = append(transfers, ConnectionTransfer{
|
||||
ID: t.GetID(),
|
||||
OperationType: operationType,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
|
||||
Size: t.GetSize(),
|
||||
VirtualPath: t.GetVirtualPath(),
|
||||
ID: t.GetID(),
|
||||
OperationType: operationType,
|
||||
StartTime: util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
|
||||
Size: t.GetSize(),
|
||||
VirtualPath: t.GetVirtualPath(),
|
||||
MaxAllowedSize: t.GetMaxAllowedSize(),
|
||||
ULSize: t.GetUploadedSize(),
|
||||
DLSize: t.GetDownloadedSize(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -181,7 +223,7 @@ func (c *BaseConnection) SignalTransfersAbort() error {
|
|||
}
|
||||
|
||||
for _, t := range c.activeTransfers {
|
||||
t.SignalClose()
|
||||
t.SignalClose(ErrTransferAborted)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1208,9 +1250,8 @@ func (c *BaseConnection) GetOpUnsupportedError() error {
|
|||
}
|
||||
}
|
||||
|
||||
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
|
||||
func (c *BaseConnection) GetQuotaExceededError() error {
|
||||
switch c.protocol {
|
||||
func getQuotaExceededError(protocol string) error {
|
||||
switch protocol {
|
||||
case ProtocolSFTP:
|
||||
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error())
|
||||
case ProtocolFTP:
|
||||
|
@ -1220,6 +1261,11 @@ func (c *BaseConnection) GetQuotaExceededError() error {
|
|||
}
|
||||
}
|
||||
|
||||
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
|
||||
func (c *BaseConnection) GetQuotaExceededError() error {
|
||||
return getQuotaExceededError(c.protocol)
|
||||
}
|
||||
|
||||
// IsQuotaExceededError returns true if the given error is a quota exceeded error
|
||||
func (c *BaseConnection) IsQuotaExceededError(err error) bool {
|
||||
switch c.protocol {
|
||||
|
|
|
@ -20,7 +20,7 @@ var (
|
|||
|
||||
// BaseTransfer contains protocols common transfer details for an upload or a download.
|
||||
type BaseTransfer struct { //nolint:maligned
|
||||
ID uint64
|
||||
ID int64
|
||||
BytesSent int64
|
||||
BytesReceived int64
|
||||
Fs vfs.Fs
|
||||
|
@ -35,18 +35,21 @@ type BaseTransfer struct { //nolint:maligned
|
|||
MaxWriteSize int64
|
||||
MinWriteOffset int64
|
||||
InitialSize int64
|
||||
truncatedSize int64
|
||||
isNewFile bool
|
||||
transferType int
|
||||
AbortTransfer int32
|
||||
aTime time.Time
|
||||
mTime time.Time
|
||||
sync.Mutex
|
||||
errAbort error
|
||||
ErrTransfer error
|
||||
}
|
||||
|
||||
// NewBaseTransfer returns a new BaseTransfer and adds it to the given connection
|
||||
func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string,
|
||||
transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer {
|
||||
transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs,
|
||||
) *BaseTransfer {
|
||||
t := &BaseTransfer{
|
||||
ID: conn.GetTransferID(),
|
||||
File: file,
|
||||
|
@ -64,6 +67,7 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
|
|||
BytesReceived: 0,
|
||||
MaxWriteSize: maxWriteSize,
|
||||
AbortTransfer: 0,
|
||||
truncatedSize: truncatedSize,
|
||||
Fs: fs,
|
||||
}
|
||||
|
||||
|
@ -77,7 +81,7 @@ func (t *BaseTransfer) SetFtpMode(mode string) {
|
|||
}
|
||||
|
||||
// GetID returns the transfer ID
|
||||
func (t *BaseTransfer) GetID() uint64 {
|
||||
func (t *BaseTransfer) GetID() int64 {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
|
@ -94,19 +98,53 @@ func (t *BaseTransfer) GetSize() int64 {
|
|||
return atomic.LoadInt64(&t.BytesReceived)
|
||||
}
|
||||
|
||||
// GetDownloadedSize returns the transferred size
|
||||
func (t *BaseTransfer) GetDownloadedSize() int64 {
|
||||
return atomic.LoadInt64(&t.BytesSent)
|
||||
}
|
||||
|
||||
// GetUploadedSize returns the transferred size
|
||||
func (t *BaseTransfer) GetUploadedSize() int64 {
|
||||
return atomic.LoadInt64(&t.BytesReceived)
|
||||
}
|
||||
|
||||
// GetStartTime returns the start time
|
||||
func (t *BaseTransfer) GetStartTime() time.Time {
|
||||
return t.start
|
||||
}
|
||||
|
||||
// SignalClose signals that the transfer should be closed.
|
||||
// For same protocols, for example WebDAV, we have no
|
||||
// access to the network connection, so we use this method
|
||||
// to make the next read or write to fail
|
||||
func (t *BaseTransfer) SignalClose() {
|
||||
// GetAbortError returns the error to send to the client if the transfer was aborted
|
||||
func (t *BaseTransfer) GetAbortError() error {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
if t.errAbort != nil {
|
||||
return t.errAbort
|
||||
}
|
||||
return getQuotaExceededError(t.Connection.protocol)
|
||||
}
|
||||
|
||||
// SignalClose signals that the transfer should be closed after the next read/write.
|
||||
// The optional error argument allow to send a specific error, otherwise a generic
|
||||
// transfer aborted error is sent
|
||||
func (t *BaseTransfer) SignalClose(err error) {
|
||||
t.Lock()
|
||||
t.errAbort = err
|
||||
t.Unlock()
|
||||
atomic.StoreInt32(&(t.AbortTransfer), 1)
|
||||
}
|
||||
|
||||
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
|
||||
// an existing file
|
||||
func (t *BaseTransfer) GetTruncatedSize() int64 {
|
||||
return t.truncatedSize
|
||||
}
|
||||
|
||||
// GetMaxAllowedSize returns the max allowed size
|
||||
func (t *BaseTransfer) GetMaxAllowedSize() int64 {
|
||||
return t.MaxWriteSize
|
||||
}
|
||||
|
||||
// GetVirtualPath returns the transfer virtual path
|
||||
func (t *BaseTransfer) GetVirtualPath() string {
|
||||
return t.requestPath
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestTransferThrottling(t *testing.T) {
|
|||
wantedUploadElapsed -= wantedDownloadElapsed / 10
|
||||
wantedDownloadElapsed -= wantedDownloadElapsed / 10
|
||||
conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
|
||||
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
transfer.BytesReceived = testFileSize
|
||||
transfer.Connection.UpdateLastActivity()
|
||||
startTime := transfer.Connection.GetLastActivity()
|
||||
|
@ -75,7 +75,7 @@ func TestTransferThrottling(t *testing.T) {
|
|||
err := transfer.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, true, fs)
|
||||
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs)
|
||||
transfer.BytesSent = testFileSize
|
||||
transfer.Connection.UpdateLastActivity()
|
||||
startTime = transfer.Connection.GetLastActivity()
|
||||
|
@ -101,7 +101,8 @@ func TestRealPath(t *testing.T) {
|
|||
file, err := os.Create(testFile)
|
||||
require.NoError(t, err)
|
||||
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
|
||||
TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
rPath := transfer.GetRealFsPath(testFile)
|
||||
assert.Equal(t, testFile, rPath)
|
||||
rPath = conn.getRealFsPath(testFile)
|
||||
|
@ -138,7 +139,8 @@ func TestTruncate(t *testing.T) {
|
|||
_, err = file.Write([]byte("hello"))
|
||||
assert.NoError(t, err)
|
||||
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
|
||||
100, 0, false, fs)
|
||||
|
||||
err = conn.SetStat("/transfer_test_file", &StatAttributes{
|
||||
Size: 2,
|
||||
|
@ -155,7 +157,8 @@ func TestTruncate(t *testing.T) {
|
|||
assert.Equal(t, int64(2), fi.Size())
|
||||
}
|
||||
|
||||
transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs)
|
||||
transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
|
||||
100, 0, true, fs)
|
||||
// file.Stat will fail on a closed file
|
||||
err = conn.SetStat("/transfer_test_file", &StatAttributes{
|
||||
Size: 2,
|
||||
|
@ -165,7 +168,7 @@ func TestTruncate(t *testing.T) {
|
|||
err = transfer.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
_, err = transfer.Truncate("mismatch", 0)
|
||||
assert.EqualError(t, err, errTransferMismatch.Error())
|
||||
_, err = transfer.Truncate(testFile, 0)
|
||||
|
@ -202,7 +205,8 @@ func TestTransferErrors(t *testing.T) {
|
|||
assert.FailNow(t, "unable to open test file")
|
||||
}
|
||||
conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
|
||||
0, 0, 0, 0, true, fs)
|
||||
assert.Nil(t, transfer.cancelFn)
|
||||
assert.Equal(t, testFile, transfer.GetFsPath())
|
||||
transfer.SetCancelFn(cancelFn)
|
||||
|
@ -228,7 +232,7 @@ func TestTransferErrors(t *testing.T) {
|
|||
assert.FailNow(t, "unable to open test file")
|
||||
}
|
||||
fsPath := filepath.Join(os.TempDir(), "test_file")
|
||||
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
transfer.BytesReceived = 9
|
||||
transfer.TransferError(errFake)
|
||||
assert.Error(t, transfer.ErrTransfer, errFake.Error())
|
||||
|
@ -247,7 +251,7 @@ func TestTransferErrors(t *testing.T) {
|
|||
if !assert.NoError(t, err) {
|
||||
assert.FailNow(t, "unable to open test file")
|
||||
}
|
||||
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
transfer.BytesReceived = 9
|
||||
// the file is closed from the embedding struct before to call close
|
||||
err = file.Close()
|
||||
|
@ -273,7 +277,8 @@ func TestRemovePartialCryptoFile(t *testing.T) {
|
|||
},
|
||||
}
|
||||
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
|
||||
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
|
||||
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
|
||||
0, 0, 0, 0, true, fs)
|
||||
transfer.ErrTransfer = errors.New("test error")
|
||||
_, err = transfer.getUploadFileSize()
|
||||
assert.Error(t, err)
|
||||
|
|
167
common/transferschecker.go
Normal file
167
common/transferschecker.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/logger"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
type overquotaTransfer struct {
|
||||
ConnID string
|
||||
TransferID int64
|
||||
}
|
||||
|
||||
// TransfersChecker defines the interface that transfer checkers must implement.
|
||||
// A transfer checker ensure that multiple concurrent transfers does not exceeded
|
||||
// the remaining user quota
|
||||
type TransfersChecker interface {
|
||||
AddTransfer(transfer dataprovider.ActiveTransfer)
|
||||
RemoveTransfer(ID int64, connectionID string)
|
||||
UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string)
|
||||
GetOverquotaTransfers() []overquotaTransfer
|
||||
}
|
||||
|
||||
func getTransfersChecker() TransfersChecker {
|
||||
return &transfersCheckerMem{}
|
||||
}
|
||||
|
||||
type transfersCheckerMem struct {
|
||||
sync.RWMutex
|
||||
transfers []dataprovider.ActiveTransfer
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
t.transfers = append(t.transfers, transfer)
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
for idx, transfer := range t.transfers {
|
||||
if transfer.ID == ID && transfer.ConnID == connectionID {
|
||||
lastIdx := len(t.transfers) - 1
|
||||
t.transfers[idx] = t.transfers[lastIdx]
|
||||
t.transfers = t.transfers[:lastIdx]
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
|
||||
for idx := range t.transfers {
|
||||
if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
|
||||
t.transfers[idx].CurrentDLSize = dlSize
|
||||
t.transfers[idx].CurrentULSize = ulSize
|
||||
t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
|
||||
var result int64
|
||||
|
||||
if folderName != "" {
|
||||
for _, folder := range user.VirtualFolders {
|
||||
if folder.Name == folderName {
|
||||
if folder.QuotaSize > 0 {
|
||||
return folder.QuotaSize - folder.UsedQuotaSize, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if user.QuotaSize > 0 {
|
||||
return user.QuotaSize - user.UsedQuotaSize, nil
|
||||
}
|
||||
}
|
||||
|
||||
return result, errors.New("no quota limit defined")
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
|
||||
usersToFetch := make(map[string]bool)
|
||||
aggregations := make(map[string][]dataprovider.ActiveTransfer)
|
||||
for _, transfer := range t.transfers {
|
||||
key := transfer.GetKey()
|
||||
aggregations[key] = append(aggregations[key], transfer)
|
||||
if len(aggregations[key]) > 1 {
|
||||
if transfer.FolderName != "" {
|
||||
usersToFetch[transfer.Username] = true
|
||||
} else {
|
||||
if _, ok := usersToFetch[transfer.Username]; !ok {
|
||||
usersToFetch[transfer.Username] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return usersToFetch, aggregations
|
||||
}
|
||||
|
||||
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
|
||||
usersToFetch, aggregations := t.aggregateTransfers()
|
||||
|
||||
if len(usersToFetch) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
usersMap := make(map[string]dataprovider.User)
|
||||
|
||||
for _, user := range users {
|
||||
usersMap[user.Username] = user
|
||||
}
|
||||
|
||||
var overquotaTransfers []overquotaTransfer
|
||||
|
||||
for _, transfers := range aggregations {
|
||||
if len(transfers) > 1 {
|
||||
username := transfers[0].Username
|
||||
folderName := transfers[0].FolderName
|
||||
// transfer type is always upload for now
|
||||
remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var usedDiskQuota int64
|
||||
for _, tr := range transfers {
|
||||
// We optimistically assume that a cloud transfer that replaces an existing
|
||||
// file will be successful
|
||||
usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
|
||||
}
|
||||
logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v",
|
||||
username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
|
||||
if usedDiskQuota > remaningDiskQuota {
|
||||
for _, tr := range transfers {
|
||||
if tr.CurrentULSize > tr.TruncatedSize {
|
||||
overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
|
||||
ConnID: tr.ConnID,
|
||||
TransferID: tr.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return overquotaTransfers
|
||||
}
|
449
common/transferschecker_test.go
Normal file
449
common/transferschecker_test.go
Normal file
|
@ -0,0 +1,449 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sftpgo/sdk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
"github.com/drakkan/sftpgo/v2/vfs"
|
||||
)
|
||||
|
||||
// TestTransfersCheckerDiskQuota verifies that the transfers checker aborts
// concurrent uploads when their combined size exceeds the user quota (120
// bytes) or a virtual folder quota (100 bytes), and that downloads, single
// transfers and in-progress checks are left untouched.
func TestTransfersCheckerDiskQuota(t *testing.T) {
	username := "transfers_check_username"
	folderName := "test_transfers_folder"
	vdirPath := "/vdir"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:  username,
			Password:  "testpwd",
			HomeDir:   filepath.Join(os.TempDir(), username),
			Status:    1,
			QuotaSize: 120,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		VirtualFolders: []vfs.VirtualFolder{
			{
				BaseVirtualFolder: vfs.BaseVirtualFolder{
					Name:       folderName,
					MappedPath: filepath.Join(os.TempDir(), folderName),
				},
				VirtualPath: vdirPath,
				QuotaSize:   100,
			},
		},
	}

	err := dataprovider.AddUser(&user, "", "")
	assert.NoError(t, err)
	user, err = dataprovider.UserExists(username)
	assert.NoError(t, err)

	connID1 := xid.New().String()
	fsUser, err := user.GetFilesystemForPath("/file1", connID1)
	assert.NoError(t, err)
	conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user)
	fakeConn1 := &fakeConnection{
		BaseConnection: conn1,
	}
	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
		"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser)
	transfer1.BytesReceived = 150
	Connections.Add(fakeConn1)
	// the transferschecker will do nothing if there is only one ongoing transfer
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)

	connID2 := xid.New().String()
	conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user)
	fakeConn2 := &fakeConnection{
		BaseConnection: conn2,
	}
	// transfer2 has a truncated size of 40 bytes: those bytes are already
	// accounted for and must not count against the remaining quota
	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
		"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser)
	transfer1.BytesReceived = 50
	transfer2.BytesReceived = 60
	Connections.Add(fakeConn2)

	connID3 := xid.New().String()
	conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
	fakeConn3 := &fakeConnection{
		BaseConnection: conn3,
	}
	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
		"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser)
	transfer3.BytesReceived = 60 // this value will be ignored, this is a download
	Connections.Add(fakeConn3)

	// the transfers are not overquota
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)

	transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)
	transfer1.BytesReceived = 120
	// we are now overquota
	// if another check is in progress nothing is done
	atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)
	atomic.StoreInt32(&Connections.transfersCheckStatus, 0)

	// both uploads are aborted, the download is only aborted through
	// GetAbortError, not via errAbort
	Connections.checkTransfers()
	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
	assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
	assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError()))
	assert.Nil(t, transfer3.errAbort)
	assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError()))
	// update the user quota size
	user.QuotaSize = 1000
	err = dataprovider.UpdateUser(&user, "", "")
	assert.NoError(t, err)
	transfer1.errAbort = nil
	transfer2.errAbort = nil
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)

	// QuotaSize 0 means unlimited quota
	user.QuotaSize = 0
	err = dataprovider.UpdateUser(&user, "", "")
	assert.NoError(t, err)
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)
	// now check a public folder
	transfer1.BytesReceived = 0
	transfer2.BytesReceived = 0
	connID4 := xid.New().String()
	fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
	assert.NoError(t, err)
	conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
	fakeConn4 := &fakeConnection{
		BaseConnection: conn4,
	}
	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
		filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
		100, 0, true, fsFolder)
	Connections.Add(fakeConn4)
	connID5 := xid.New().String()
	conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
	fakeConn5 := &fakeConnection{
		BaseConnection: conn5,
	}
	transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"),
		filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
		100, 0, true, fsFolder)

	Connections.Add(fakeConn5)
	// 50 + 40 = 90 bytes, below the 100 bytes folder quota
	transfer4.BytesReceived = 50
	transfer5.BytesReceived = 40
	Connections.checkTransfers()
	assert.Nil(t, transfer4.errAbort)
	assert.Nil(t, transfer5.errAbort)
	// 50 + 60 = 110 bytes, over the folder quota: both uploads are aborted
	transfer5.BytesReceived = 60
	Connections.checkTransfers()
	assert.Nil(t, transfer1.errAbort)
	assert.Nil(t, transfer2.errAbort)
	assert.Nil(t, transfer3.errAbort)
	assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort))
	assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort))

	// with the provider closed the quota check fails and no transfer is aborted
	if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName {
		providerConf := dataprovider.GetProviderConfig()
		err = dataprovider.Close()
		assert.NoError(t, err)

		transfer4.errAbort = nil
		transfer5.errAbort = nil
		Connections.checkTransfers()
		assert.Nil(t, transfer1.errAbort)
		assert.Nil(t, transfer2.errAbort)
		assert.Nil(t, transfer3.errAbort)
		assert.Nil(t, transfer4.errAbort)
		assert.Nil(t, transfer5.errAbort)

		err = dataprovider.Initialize(providerConf, configDir, true)
		assert.NoError(t, err)
	}

	// cleanup
	Connections.Remove(fakeConn1.GetID())
	Connections.Remove(fakeConn2.GetID())
	Connections.Remove(fakeConn3.GetID())
	Connections.Remove(fakeConn4.GetID())
	Connections.Remove(fakeConn5.GetID())
	stats := Connections.GetStats()
	assert.Len(t, stats, 0)

	err = dataprovider.DeleteUser(user.Username, "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = dataprovider.DeleteFolder(folderName, "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(os.TempDir(), folderName))
	assert.NoError(t, err)
}
|
||||
|
||||
// TestAggregateTransfers verifies the aggregation logic of the in-memory
// transfers checker: transfers are grouped by (username, folder, type) key,
// downloads and lone uploads never mark a user for fetching, and the map
// value in usersToFetch becomes true only when a user has multiple upload
// aggregations involving a virtual folder.
func TestAggregateTransfers(t *testing.T) {
	checker := transfersCheckerMem{}
	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "1",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})
	// a single upload: nothing to fetch
	usersToFetch, aggregations := checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 1)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferDownload,
		ConnID:        "2",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 0,
		CurrentDLSize: 100,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	// a download creates its own aggregation and never triggers a fetch
	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 2)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "3",
		Username:      "user",
		FolderName:    "folder",
		TruncatedSize: 0,
		CurrentULSize: 10,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	// same user but different folder: separate aggregation
	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 3)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "4",
		Username:      "user1",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 0)
	assert.Len(t, aggregations, 4)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "5",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	// two concurrent uploads for "user" in the home dir: the user must be
	// fetched, without folders (value false)
	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok := usersToFetch["user"]
	assert.True(t, ok)
	assert.False(t, val)
	assert.Len(t, aggregations, 4)
	// "user0" is the key for username "user", empty folder, type upload (0)
	aggregate, ok := aggregations["user0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "6",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.False(t, val)
	assert.Len(t, aggregations, 4)
	aggregate, ok = aggregations["user0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 3)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "7",
		Username:      "user",
		FolderName:    "folder",
		TruncatedSize: 0,
		CurrentULSize: 10,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	// concurrent uploads inside a virtual folder: folders must be fetched
	// too, so the map value switches to true
	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.True(t, val)
	assert.Len(t, aggregations, 4)
	aggregate, ok = aggregations["user0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 3)
	aggregate, ok = aggregations["userfolder0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)

	checker.AddTransfer(dataprovider.ActiveTransfer{
		ID:            1,
		Type:          TransferUpload,
		ConnID:        "8",
		Username:      "user",
		FolderName:    "",
		TruncatedSize: 0,
		CurrentULSize: 100,
		CurrentDLSize: 0,
		CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
	})

	usersToFetch, aggregations = checker.aggregateTransfers()
	assert.Len(t, usersToFetch, 1)
	val, ok = usersToFetch["user"]
	assert.True(t, ok)
	assert.True(t, val)
	assert.Len(t, aggregations, 4)
	aggregate, ok = aggregations["user0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 4)
	aggregate, ok = aggregations["userfolder0"]
	assert.True(t, ok)
	assert.Len(t, aggregate, 2)
}
|
||||
|
||||
// TestGetUsersForQuotaCheck verifies that GetUsersForQuotaCheck returns only
// the requested users (50 requested, 40 existing) and that virtual folders
// are populated only for the users whose map value is true (even indexes).
func TestGetUsersForQuotaCheck(t *testing.T) {
	usersToFetch := make(map[string]bool)
	for i := 0; i < 50; i++ {
		// even-indexed users request their virtual folders too
		usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
	}

	// no users exist yet
	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 0)

	for i := 0; i < 40; i++ {
		user := dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username:  fmt.Sprintf("user%v", i),
				Password:  "pwd",
				HomeDir:   filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)),
				Status:    1,
				QuotaSize: 120,
				Permissions: map[string][]string{
					"/": {dataprovider.PermAny},
				},
			},
			VirtualFolders: []vfs.VirtualFolder{
				{
					BaseVirtualFolder: vfs.BaseVirtualFolder{
						Name:       fmt.Sprintf("f%v", i),
						MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
					},
					VirtualPath: "/vfolder",
					QuotaSize:   100,
				},
			},
		}
		err = dataprovider.AddUser(&user, "", "")
		assert.NoError(t, err)
		err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false)
		assert.NoError(t, err)
	}

	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 40)

	for _, user := range users {
		userIdxStr := strings.Replace(user.Username, "user", "", 1)
		userIdx, err := strconv.Atoi(userIdxStr)
		assert.NoError(t, err)
		if userIdx%2 == 0 {
			// folders were requested: quota fields must be populated
			if assert.Len(t, user.VirtualFolders, 1, user.Username) {
				assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize)
				assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize)
			}
		} else {
			// SQL providers skip the folders join for these users
			switch dataprovider.GetProviderStatus().Driver {
			case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
				dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
				assert.Len(t, user.VirtualFolders, 0, user.Username)
			}
		}
	}

	// cleanup
	for i := 0; i < 40; i++ {
		err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "")
		assert.NoError(t, err)
		err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "")
		assert.NoError(t, err)
	}

	users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
	assert.NoError(t, err)
	assert.Len(t, users, 0)
}
|
|
@ -647,6 +647,53 @@ func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
|
||||
users := make([]User, 0, 30)
|
||||
|
||||
err := p.dbHandle.View(func(tx *bolt.Tx) error {
|
||||
bucket, err := getUsersBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
foldersBucket, err := getFoldersBucket(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
var user User
|
||||
err := json.Unmarshal(v, &user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
needFolders, ok := toFetch[user.Username]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if needFolders && len(user.VirtualFolders) > 0 {
|
||||
var folders []vfs.VirtualFolder
|
||||
for idx := range user.VirtualFolders {
|
||||
folder := &user.VirtualFolders[idx]
|
||||
baseFolder, err := folderExistsInternal(folder.Name, foldersBucket)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
folder.BaseVirtualFolder = baseFolder
|
||||
folders = append(folders, *folder)
|
||||
}
|
||||
user.VirtualFolders = folders
|
||||
}
|
||||
|
||||
user.SetEmptySecretsIfNil()
|
||||
user.PrepareForRendering()
|
||||
users = append(users, user)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getUsers(limit int, offset int, order string) ([]User, error) {
|
||||
users := make([]User, 0, limit)
|
||||
var err error
|
||||
|
|
|
@ -381,6 +381,26 @@ func (c *Config) IsDefenderSupported() bool {
|
|||
}
|
||||
}
|
||||
|
||||
// ActiveTransfer defines an active protocol transfer
type ActiveTransfer struct {
	ID            int64  // transfer identifier
	Type          int    // transfer type: upload or download
	ConnID        string // identifier of the owning connection
	Username      string
	FolderName    string // virtual folder name, empty for the user home dir
	TruncatedSize int64  // bytes truncated at transfer start, already counted in quota
	CurrentULSize int64  // bytes uploaded so far
	CurrentDLSize int64  // bytes downloaded so far
	CreatedAt     int64  // creation time, as milliseconds since epoch
	UpdatedAt     int64  // last update time, as milliseconds since epoch
}

// GetKey returns an aggregation key.
// The same key will be returned for similar transfers
// (same username, folder name and transfer type).
// NOTE(review): the key is a plain concatenation, so distinct
// (username, folder) pairs could theoretically collide, e.g.
// ("ab", "c") and ("a", "bc"); confirm usernames/folder names
// make this impossible, or add a separator.
func (t *ActiveTransfer) GetKey() string {
	return fmt.Sprintf("%v%v%v", t.Username, t.FolderName, t.Type)
}
|
||||
|
||||
// DefenderEntry defines a defender entry
|
||||
type DefenderEntry struct {
|
||||
ID int64 `json:"-"`
|
||||
|
@ -476,6 +496,7 @@ type Provider interface {
|
|||
getUsers(limit int, offset int, order string) ([]User, error)
|
||||
dumpUsers() ([]User, error)
|
||||
getRecentlyUpdatedUsers(after int64) ([]User, error)
|
||||
getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error)
|
||||
updateLastLogin(username string) error
|
||||
updateAdminLastLogin(username string) error
|
||||
setUpdatedAt(username string)
|
||||
|
@ -1268,6 +1289,11 @@ func GetUsers(limit, offset int, order string) ([]User, error) {
|
|||
return provider.getUsers(limit, offset, order)
|
||||
}
|
||||
|
||||
// GetUsersForQuotaCheck returns the users with the fields required for a quota check.
// The map key is the username; when the value is true the user's virtual
// folders are fetched too.
func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return provider.getUsersForQuotaCheck(toFetch)
}
|
||||
|
||||
// AddFolder adds a new virtual folder.
|
||||
func AddFolder(folder *vfs.BaseVirtualFolder) error {
|
||||
return provider.addFolder(folder)
|
||||
|
|
|
@ -349,6 +349,7 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) {
|
|||
for _, username := range p.dbHandle.usernames {
|
||||
u := p.dbHandle.users[username]
|
||||
user := u.getACopy()
|
||||
p.addVirtualFoldersToUser(&user)
|
||||
err = addCredentialsToUser(&user)
|
||||
if err != nil {
|
||||
return users, err
|
||||
|
@ -376,6 +377,28 @@ func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
|
||||
users := make([]User, 0, 30)
|
||||
p.dbHandle.Lock()
|
||||
defer p.dbHandle.Unlock()
|
||||
if p.dbHandle.isClosed {
|
||||
return users, errMemoryProviderClosed
|
||||
}
|
||||
for _, username := range p.dbHandle.usernames {
|
||||
if val, ok := toFetch[username]; ok {
|
||||
u := p.dbHandle.users[username]
|
||||
user := u.getACopy()
|
||||
if val {
|
||||
p.addVirtualFoldersToUser(&user)
|
||||
}
|
||||
user.PrepareForRendering()
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User, error) {
|
||||
users := make([]User, 0, limit)
|
||||
var err error
|
||||
|
@ -396,6 +419,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
|
|||
}
|
||||
u := p.dbHandle.users[username]
|
||||
user := u.getACopy()
|
||||
p.addVirtualFoldersToUser(&user)
|
||||
user.PrepareForRendering()
|
||||
users = append(users, user)
|
||||
if len(users) >= limit {
|
||||
|
@ -411,6 +435,7 @@ func (p *MemoryProvider) getUsers(limit int, offset int, order string) ([]User,
|
|||
username := p.dbHandle.usernames[i]
|
||||
u := p.dbHandle.users[username]
|
||||
user := u.getACopy()
|
||||
p.addVirtualFoldersToUser(&user)
|
||||
user.PrepareForRendering()
|
||||
users = append(users, user)
|
||||
if len(users) >= limit {
|
||||
|
@ -427,7 +452,12 @@ func (p *MemoryProvider) userExists(username string) (User, error) {
|
|||
if p.dbHandle.isClosed {
|
||||
return User{}, errMemoryProviderClosed
|
||||
}
|
||||
return p.userExistsInternal(username)
|
||||
user, err := p.userExistsInternal(username)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
p.addVirtualFoldersToUser(&user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
|
||||
|
@ -632,6 +662,22 @@ func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolde
|
|||
return folders
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
|
||||
if len(user.VirtualFolders) > 0 {
|
||||
var folders []vfs.VirtualFolder
|
||||
for idx := range user.VirtualFolders {
|
||||
folder := &user.VirtualFolders[idx]
|
||||
baseFolder, err := p.folderExistsInternal(folder.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
folder.BaseVirtualFolder = baseFolder.GetACopy()
|
||||
folders = append(folders, *folder)
|
||||
}
|
||||
user.VirtualFolders = folders
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) removeUserFromFolderMapping(folderName, username string) {
|
||||
folder, err := p.folderExistsInternal(folderName)
|
||||
if err == nil {
|
||||
|
@ -655,7 +701,8 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold
|
|||
}
|
||||
|
||||
func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64,
|
||||
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) {
|
||||
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error,
|
||||
) {
|
||||
folder, err := p.folderExistsInternal(baseFolder.Name)
|
||||
if err == nil {
|
||||
// exists
|
||||
|
|
|
@ -186,6 +186,10 @@ func (p *MySQLProvider) getUsers(limit int, offset int, order string) ([]User, e
|
|||
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
|
||||
}
|
||||
|
||||
// getUsersForQuotaCheck returns the users listed in toFetch with the fields
// required for a quota check, delegating to the shared SQL implementation.
func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
|
||||
|
||||
func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
|
||||
return sqlCommonDumpFolders(p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -198,6 +198,10 @@ func (p *PGSQLProvider) getUsers(limit int, offset int, order string) ([]User, e
|
|||
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
|
||||
}
|
||||
|
||||
// getUsersForQuotaCheck returns the users listed in toFetch with the fields
// required for a quota check, delegating to the shared SQL implementation.
func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
|
||||
|
||||
func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
|
||||
return sqlCommonDumpFolders(p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -939,6 +939,90 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
|
|||
return getUsersWithVirtualFolders(ctx, users, dbHandle)
|
||||
}
|
||||
|
||||
// sqlCommonGetUsersForQuotaCheck returns the users referenced in toFetch with
// the fields required for a quota check. Usernames are queried in batches of
// at most 30 per statement. Users whose map value is true additionally get
// their virtual folders joined in.
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
	users := make([]User, 0, 30)

	usernames := make([]string, 0, len(toFetch))
	for k := range toFetch {
		usernames = append(usernames, k)
	}

	// fetch in chunks; maxUsers only shrinks as usernames is consumed
	maxUsers := 30
	for len(usernames) > 0 {
		if maxUsers > len(usernames) {
			maxUsers = len(usernames)
		}
		usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
		if err != nil {
			return users, err
		}
		users = append(users, usersRange...)
		usernames = usernames[maxUsers:]
	}

	// partition the result: users needing folders go to usersWithFolders,
	// the rest are compacted in place into the front of "users"
	var usersWithFolders []User

	validIdx := 0
	for _, user := range users {
		if toFetch[user.Username] {
			usersWithFolders = append(usersWithFolders, user)
		} else {
			users[validIdx] = user
			validIdx++
		}
	}
	users = users[:validIdx]
	if len(usersWithFolders) == 0 {
		return users, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	// join the virtual folders for the users that requested them
	usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
	if err != nil {
		return users, err
	}
	users = append(users, usersWithFolders...)
	return users, nil
}
|
||||
|
||||
// sqlCommonGetUsersRangeForQuotaCheck selects id, username, quota_size and
// used_quota_size for the given usernames in a single query.
// len(usernames) must not exceed the number of available SQL placeholders.
func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
	users := make([]User, 0, len(usernames))
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUsersForQuotaCheckQuery(len(usernames))
	stmt, err := dbHandle.PrepareContext(ctx, q)
	if err != nil {
		providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
		return users, err
	}
	defer stmt.Close()

	// expand the usernames into the query arguments
	queryArgs := make([]interface{}, 0, len(usernames))
	for idx := range usernames {
		queryArgs = append(queryArgs, usernames[idx])
	}

	rows, err := stmt.QueryContext(ctx, queryArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var user User
		err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize)
		if err != nil {
			return users, err
		}
		users = append(users, user)
	}

	// surface any iteration error from the cursor
	return users, rows.Err()
}
|
||||
|
||||
func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier) ([]User, error) {
|
||||
users := make([]User, 0, limit)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
|
|
|
@ -183,6 +183,10 @@ func (p *SQLiteProvider) getUsers(limit int, offset int, order string) ([]User,
|
|||
return sqlCommonGetUsers(limit, offset, order, p.dbHandle)
|
||||
}
|
||||
|
||||
// getUsersForQuotaCheck returns the users listed in toFetch with the fields
// required for a quota check, delegating to the shared SQL implementation.
func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}
|
||||
|
||||
func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
|
||||
return sqlCommonDumpFolders(p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ const (
|
|||
|
||||
func getSQLPlaceholders() []string {
|
||||
var placeholders []string
|
||||
for i := 1; i <= 30; i++ {
|
||||
for i := 1; i <= 50; i++ {
|
||||
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%v", i))
|
||||
} else {
|
||||
|
@ -263,6 +263,23 @@ func getUsersQuery(order string) string {
|
|||
order, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getUsersForQuotaCheckQuery(numArgs int) string {
|
||||
var sb strings.Builder
|
||||
for idx := 0; idx < numArgs; idx++ {
|
||||
if sb.Len() == 0 {
|
||||
sb.WriteString("(")
|
||||
} else {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(sqlPlaceholders[idx])
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(")")
|
||||
}
|
||||
return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size FROM %v WHERE username IN %v`,
|
||||
sqlTableUsers, sb.String())
|
||||
}
|
||||
|
||||
// getRecentlyUpdatedUsersQuery returns the query to select the users whose
// updated_at field is greater than or equal to the bound timestamp argument.
func getRecentlyUpdatedUsersQuery() string {
	return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}
|
||||
|
|
|
@ -335,8 +335,8 @@ func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int6
|
|||
return nil, c.GetFsError(fs, err)
|
||||
}
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
baseTransfer.SetFtpMode(c.getFTPMode())
|
||||
t := newTransfer(baseTransfer, nil, r, offset)
|
||||
|
||||
|
@ -402,7 +402,7 @@ func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath,
|
|||
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
|
||||
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
|
||||
baseTransfer.SetFtpMode(c.getFTPMode())
|
||||
t := newTransfer(baseTransfer, w, nil, 0)
|
||||
|
||||
|
@ -452,6 +452,7 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
|
|||
}
|
||||
|
||||
initialSize := int64(0)
|
||||
truncatedSize := int64(0) // bytes truncated and not included in quota
|
||||
if isResume {
|
||||
c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize)
|
||||
minWriteOffset = fileSize
|
||||
|
@ -473,13 +474,14 @@ func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolve
|
|||
}
|
||||
} else {
|
||||
initialSize = fileSize
|
||||
truncatedSize = fileSize
|
||||
}
|
||||
}
|
||||
|
||||
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
|
||||
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
|
||||
baseTransfer.SetFtpMode(c.getFTPMode())
|
||||
t := newTransfer(baseTransfer, w, nil, 0)
|
||||
|
||||
|
|
|
@ -808,7 +808,7 @@ func TestTransferErrors(t *testing.T) {
|
|||
clientContext: mockCC,
|
||||
}
|
||||
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
tr := newTransfer(baseTransfer, nil, nil, 0)
|
||||
err = tr.Close()
|
||||
assert.NoError(t, err)
|
||||
|
@ -826,7 +826,7 @@ func TestTransferErrors(t *testing.T) {
|
|||
r, _, err := pipeat.Pipe()
|
||||
assert.NoError(t, err)
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
|
||||
common.TransferUpload, 0, 0, 0, false, fs)
|
||||
common.TransferUpload, 0, 0, 0, 0, false, fs)
|
||||
tr = newTransfer(baseTransfer, nil, r, 10)
|
||||
pos, err := tr.Seek(10, 0)
|
||||
assert.NoError(t, err)
|
||||
|
@ -838,7 +838,7 @@ func TestTransferErrors(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
pipeWriter := vfs.NewPipeWriter(w)
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile,
|
||||
common.TransferUpload, 0, 0, 0, false, fs)
|
||||
common.TransferUpload, 0, 0, 0, 0, false, fs)
|
||||
tr = newTransfer(baseTransfer, pipeWriter, nil, 0)
|
||||
|
||||
err = r.Close()
|
||||
|
|
16
go.mod
16
go.mod
|
@ -7,8 +7,8 @@ require (
|
|||
github.com/Azure/azure-storage-blob-go v0.14.0
|
||||
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
|
||||
github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
|
||||
github.com/aws/aws-sdk-go v1.42.35
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.5
|
||||
github.com/aws/aws-sdk-go v1.42.37
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.6
|
||||
github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
|
||||
github.com/fclairamb/ftpserverlib v0.17.0
|
||||
github.com/fclairamb/go-log v0.2.0
|
||||
|
@ -35,7 +35,7 @@ require (
|
|||
github.com/pires/go-proxyproto v0.6.1
|
||||
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae
|
||||
github.com/pquerna/otp v1.3.0
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/prometheus/client_golang v1.12.0
|
||||
github.com/rs/cors v1.8.2
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50
|
||||
|
@ -62,8 +62,8 @@ require (
|
|||
|
||||
require (
|
||||
cloud.google.com/go v0.100.2 // indirect
|
||||
cloud.google.com/go/compute v1.0.0 // indirect
|
||||
cloud.google.com/go/iam v0.1.0 // indirect
|
||||
cloud.google.com/go/compute v1.1.0 // indirect
|
||||
cloud.google.com/go/iam v0.1.1 // indirect
|
||||
github.com/Azure/azure-pipeline-go v0.2.3 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/boombuler/barcode v1.0.1 // indirect
|
||||
|
@ -79,7 +79,7 @@ require (
|
|||
github.com/goccy/go-json v0.9.3 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/google/go-cmp v0.5.6 // indirect
|
||||
github.com/google/go-cmp v0.5.7 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.1.1 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
|
@ -126,10 +126,10 @@ require (
|
|||
golang.org/x/tools v0.1.8 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
|
||||
google.golang.org/grpc v1.43.0 // indirect
|
||||
google.golang.org/protobuf v1.27.1 // indirect
|
||||
gopkg.in/ini.v1 v1.66.2 // indirect
|
||||
gopkg.in/ini.v1 v1.66.3 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
)
|
||||
|
|
31
go.sum
31
go.sum
|
@ -46,14 +46,14 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
|
|||
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
|
||||
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
|
||||
cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
|
||||
cloud.google.com/go/compute v1.0.0 h1:SJYBzih8Jj9EUm6IDirxKG0I0AGWduhtb6BmdqWarw4=
|
||||
cloud.google.com/go/compute v1.0.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow=
|
||||
cloud.google.com/go/compute v1.1.0 h1:pyPhehLfZ6pVzRgJmXGYvCY4K7WSWRhVw0AwhgVvS84=
|
||||
cloud.google.com/go/compute v1.1.0/go.mod h1:2NIffxgWfORSI7EOYMFatGTfjMLnqrOKBEyYb6NoRgA=
|
||||
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
|
||||
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
|
||||
cloud.google.com/go/firestore v1.5.0/go.mod h1:c4nNYR1qdq7eaZ+jSc5fonrQN2k3M7sWATcYTiakjEo=
|
||||
cloud.google.com/go/firestore v1.6.1/go.mod h1:asNXNOzBdyVQmEU+ggO8UPodTkEVFW5Qx+rwHnAz+EY=
|
||||
cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY=
|
||||
cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c=
|
||||
cloud.google.com/go/iam v0.1.1 h1:4CapQyNFjiksks1/x7jsvsygFPhihslYk5GptIrlX68=
|
||||
cloud.google.com/go/iam v0.1.1/go.mod h1:CKqrcnI/suGpybEHxZ7BMehL0oA4LpdyJdUlTl9jVMw=
|
||||
cloud.google.com/go/kms v0.1.0 h1:VXAb5OzejDcyhFzIDeZ5n5AUdlsFnCyexuascIwWMj0=
|
||||
cloud.google.com/go/kms v0.1.0/go.mod h1:8Qp8PCAypHg4FdmlyW1QRAv09BGQ9Uzh7JnmIZxPk+c=
|
||||
cloud.google.com/go/monitoring v0.1.0/go.mod h1:Hpm3XfzJv+UTiXzCG5Ffp0wijzHTC7Cv4eR7o3x/fEE=
|
||||
|
@ -141,8 +141,8 @@ github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgI
|
|||
github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
|
||||
github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
|
||||
github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
|
||||
github.com/aws/aws-sdk-go v1.42.35 h1:N4N9buNs4YlosI9N0+WYrq8cIZwdgv34yRbxzZlTvFs=
|
||||
github.com/aws/aws-sdk-go v1.42.35/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
|
||||
github.com/aws/aws-sdk-go v1.42.37 h1:EIziSq3REaoi1LgUBgxoQr29DQS7GYHnBbZPajtJmXM=
|
||||
github.com/aws/aws-sdk-go v1.42.37/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.4.0/go.mod h1:dgGR+Qq7Wjcd4AOAW5Rf5Tnv3+x7ed6kETXyS9WCuAY=
|
||||
|
@ -190,8 +190,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
|
|||
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.5 h1:tfPdGHO5YpmrpN2ikJZYpaSGgU8WALwwjH3s+msiTQ0=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.5/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.6 h1:LTh++UIVvmDBihDo1oYbM8+OruXheusw+ILCONlAm/w=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.6/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc=
|
||||
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c=
|
||||
|
@ -343,8 +343,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
|
||||
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
|
||||
github.com/google/go-replayers/grpcreplay v1.1.0/go.mod h1:qzAvJ8/wi57zq7gWqaE6AwLM6miiXUQwP1S+I9icmhk=
|
||||
github.com/google/go-replayers/httpreplay v1.0.0/go.mod h1:LJhKoTwS5Wy5Ld/peq8dFFG5OfJyHEz7ft+DsTUv25M=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
|
@ -649,8 +650,9 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
|
|||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU=
|
||||
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
|
||||
github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ=
|
||||
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
|
||||
github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg=
|
||||
github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
|
@ -1075,6 +1077,7 @@ google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUb
|
|||
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
|
||||
google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw=
|
||||
google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo=
|
||||
google.golang.org/api v0.64.0/go.mod h1:931CdxA8Rm4t6zqTFGSsgwbAEZ2+GMYurbndwSimebM=
|
||||
google.golang.org/api v0.65.0 h1:MTW9c+LIBAbwoS1Gb+YV7NjFBt2f7GtAS5hIzh2NjgQ=
|
||||
google.golang.org/api v0.65.0/go.mod h1:ArYhxgGadlWmqO1IqVujw6Cs8IdD33bTmzKo2Sh+cbg=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
|
@ -1162,8 +1165,9 @@ google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ6
|
|||
google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20220107163113-42d7afdf6368/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0 h1:aCsSLXylHWFno0r4S3joLpiaWayvqd2Mn4iSvx4WZZc=
|
||||
google.golang.org/genproto v0.0.0-20220114231437-d2e6a121cae0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20220111164026-67b88f271998/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
|
@ -1217,8 +1221,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
|
|||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||
gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
|
||||
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/ini.v1 v1.66.3 h1:jRskFVxYaMGAMUbN0UZ7niA9gzL9B49DOqE78vg0k3w=
|
||||
gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package httpd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -11,8 +10,6 @@ import (
|
|||
"github.com/drakkan/sftpgo/v2/vfs"
|
||||
)
|
||||
|
||||
var errTransferAborted = errors.New("transfer aborted")
|
||||
|
||||
type httpdFile struct {
|
||||
*common.BaseTransfer
|
||||
writer io.WriteCloser
|
||||
|
@ -42,7 +39,9 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
|
|||
// Read reads the contents to downloads.
|
||||
func (f *httpdFile) Read(p []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
|
||||
return 0, errTransferAborted
|
||||
err := f.GetAbortError()
|
||||
f.TransferError(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f.Connection.UpdateLastActivity()
|
||||
|
@ -61,7 +60,9 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
|
|||
// Write writes the contents to upload
|
||||
func (f *httpdFile) Write(p []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
|
||||
return 0, errTransferAborted
|
||||
err := f.GetAbortError()
|
||||
f.TransferError(err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f.Connection.UpdateLastActivity()
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -113,7 +114,7 @@ func (c *Connection) getFileReader(name string, offset int64, method string) (io
|
|||
}
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
0, 0, 0, 0, false, fs)
|
||||
return newHTTPDFile(baseTransfer, nil, r), nil
|
||||
}
|
||||
|
||||
|
@ -190,6 +191,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
|
|||
}
|
||||
|
||||
initialSize := int64(0)
|
||||
truncatedSize := int64(0) // bytes truncated and not included in quota
|
||||
if !isNewFile {
|
||||
if vfs.IsLocalOrSFTPFs(fs) {
|
||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||
|
@ -203,6 +205,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
|
|||
}
|
||||
} else {
|
||||
initialSize = fileSize
|
||||
truncatedSize = fileSize
|
||||
}
|
||||
if maxWriteSize > 0 {
|
||||
maxWriteSize += fileSize
|
||||
|
@ -212,7 +215,7 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
|
|||
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
|
||||
return newHTTPDFile(baseTransfer, w, nil), nil
|
||||
}
|
||||
|
||||
|
@ -232,15 +235,17 @@ func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttl
|
|||
|
||||
type throttledReader struct {
|
||||
bytesRead int64
|
||||
id uint64
|
||||
id int64
|
||||
limit int64
|
||||
r io.ReadCloser
|
||||
abortTransfer int32
|
||||
start time.Time
|
||||
conn *Connection
|
||||
mu sync.Mutex
|
||||
errAbort error
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetID() uint64 {
|
||||
func (t *throttledReader) GetID() int64 {
|
||||
return t.id
|
||||
}
|
||||
|
||||
|
@ -252,6 +257,14 @@ func (t *throttledReader) GetSize() int64 {
|
|||
return atomic.LoadInt64(&t.bytesRead)
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetDownloadedSize() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetUploadedSize() int64 {
|
||||
return atomic.LoadInt64(&t.bytesRead)
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetVirtualPath() string {
|
||||
return "**reading request body**"
|
||||
}
|
||||
|
@ -260,10 +273,31 @@ func (t *throttledReader) GetStartTime() time.Time {
|
|||
return t.start
|
||||
}
|
||||
|
||||
func (t *throttledReader) SignalClose() {
|
||||
func (t *throttledReader) GetAbortError() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.errAbort != nil {
|
||||
return t.errAbort
|
||||
}
|
||||
return common.ErrTransferAborted
|
||||
}
|
||||
|
||||
func (t *throttledReader) SignalClose(err error) {
|
||||
t.mu.Lock()
|
||||
t.errAbort = err
|
||||
t.mu.Unlock()
|
||||
atomic.StoreInt32(&(t.abortTransfer), 1)
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetTruncatedSize() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *throttledReader) GetMaxAllowedSize() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) {
|
||||
return 0, vfs.ErrVfsUnsupported
|
||||
}
|
||||
|
@ -278,7 +312,7 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
|
|||
|
||||
func (t *throttledReader) Read(p []byte) (n int, err error) {
|
||||
if atomic.LoadInt32(&t.abortTransfer) == 1 {
|
||||
return 0, errTransferAborted
|
||||
return 0, t.GetAbortError()
|
||||
}
|
||||
|
||||
t.conn.UpdateLastActivity()
|
||||
|
|
|
@ -1844,12 +1844,15 @@ func TestThrottledHandler(t *testing.T) {
|
|||
tr := &throttledReader{
|
||||
r: io.NopCloser(bytes.NewBuffer(nil)),
|
||||
}
|
||||
assert.Equal(t, int64(0), tr.GetTruncatedSize())
|
||||
err := tr.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, tr.GetRealFsPath("real path"))
|
||||
assert.False(t, tr.SetTimes("p", time.Now(), time.Now()))
|
||||
_, err = tr.Truncate("", 0)
|
||||
assert.ErrorIs(t, err, vfs.ErrVfsUnsupported)
|
||||
err = tr.GetAbortError()
|
||||
assert.ErrorIs(t, err, common.ErrTransferAborted)
|
||||
}
|
||||
|
||||
func TestHTTPDFile(t *testing.T) {
|
||||
|
@ -1879,7 +1882,7 @@ func TestHTTPDFile(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, p, p, name, common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
0, 0, 0, 0, false, fs)
|
||||
httpdFile := newHTTPDFile(baseTransfer, nil, nil)
|
||||
// the file is closed, read should fail
|
||||
buf := make([]byte, 100)
|
||||
|
@ -1899,9 +1902,9 @@ func TestHTTPDFile(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
assert.Error(t, httpdFile.ErrTransfer)
|
||||
assert.Equal(t, err, httpdFile.ErrTransfer)
|
||||
httpdFile.SignalClose()
|
||||
httpdFile.SignalClose(nil)
|
||||
_, err = httpdFile.Write(nil)
|
||||
assert.ErrorIs(t, err, errTransferAborted)
|
||||
assert.ErrorIs(t, err, common.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestChangeUserPwd(t *testing.T) {
|
||||
|
|
|
@ -85,7 +85,7 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
|
|||
}
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
0, 0, 0, 0, false, fs)
|
||||
t := newTransfer(baseTransfer, nil, r, nil)
|
||||
|
||||
return t, nil
|
||||
|
@ -364,7 +364,7 @@ func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath
|
|||
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
|
||||
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
|
||||
t := newTransfer(baseTransfer, w, nil, errForRead)
|
||||
|
||||
return t, nil
|
||||
|
@ -415,6 +415,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
|
|||
}
|
||||
|
||||
initialSize := int64(0)
|
||||
truncatedSize := int64(0) // bytes truncated and not included in quota
|
||||
if isResume {
|
||||
c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v",
|
||||
filePath, fileSize, pflags.Append)
|
||||
|
@ -436,13 +437,14 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
|
|||
}
|
||||
} else {
|
||||
initialSize = fileSize
|
||||
truncatedSize = fileSize
|
||||
}
|
||||
}
|
||||
|
||||
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
|
||||
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
|
||||
t := newTransfer(baseTransfer, w, nil, errForRead)
|
||||
|
||||
return t, nil
|
||||
|
|
|
@ -162,7 +162,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
|
|||
}
|
||||
fs := vfs.NewOsFs("", os.TempDir(), "")
|
||||
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile,
|
||||
common.TransferUpload, 10, 0, 0, 0, false, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
_, err = transfer.WriteAt([]byte("test"), 0)
|
||||
assert.Error(t, err, "upload with invalid offset must fail")
|
||||
|
@ -193,7 +194,8 @@ func TestReadWriteErrors(t *testing.T) {
|
|||
}
|
||||
fs := vfs.NewOsFs("", os.TempDir(), "")
|
||||
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
|
||||
0, 0, 0, 0, false, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
|
@ -207,7 +209,8 @@ func TestReadWriteErrors(t *testing.T) {
|
|||
|
||||
r, _, err := pipeat.Pipe()
|
||||
assert.NoError(t, err)
|
||||
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
||||
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
|
||||
0, 0, 0, 0, false, fs)
|
||||
transfer = newTransfer(baseTransfer, nil, r, nil)
|
||||
err = transfer.Close()
|
||||
assert.NoError(t, err)
|
||||
|
@ -217,7 +220,8 @@ func TestReadWriteErrors(t *testing.T) {
|
|||
r, w, err := pipeat.Pipe()
|
||||
assert.NoError(t, err)
|
||||
pipeWriter := vfs.NewPipeWriter(w)
|
||||
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
||||
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload,
|
||||
0, 0, 0, 0, false, fs)
|
||||
transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
|
||||
|
||||
err = r.Close()
|
||||
|
@ -264,7 +268,8 @@ func TestTransferCancelFn(t *testing.T) {
|
|||
}
|
||||
fs := vfs.NewOsFs("", os.TempDir(), "")
|
||||
conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload,
|
||||
0, 0, 0, 0, false, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
errFake := errors.New("fake error, this will trigger cancelFn")
|
||||
|
@ -971,8 +976,8 @@ func TestSystemCommandErrors(t *testing.T) {
|
|||
WriteError: nil,
|
||||
}
|
||||
sshCmd.connection.channel = &mockSSHChannel
|
||||
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "",
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
destBuff := make([]byte, 65535)
|
||||
dst := bytes.NewBuffer(destBuff)
|
||||
|
@ -1639,7 +1644,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(),
|
||||
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
|
||||
"/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
err = scpCommand.getUploadFileData(2, transfer)
|
||||
|
@ -1724,7 +1729,7 @@ func TestUploadError(t *testing.T) {
|
|||
file, err := os.Create(fileTempName)
|
||||
assert.NoError(t, err)
|
||||
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(),
|
||||
testfile, common.TransferUpload, 0, 0, 0, true, fs)
|
||||
testfile, common.TransferUpload, 0, 0, 0, 0, true, fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
errFake := errors.New("fake error")
|
||||
|
@ -1782,7 +1787,8 @@ func TestTransferFailingReader(t *testing.T) {
|
|||
|
||||
r, _, err := pipeat.Pipe()
|
||||
assert.NoError(t, err)
|
||||
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath),
|
||||
common.TransferUpload, 0, 0, 0, 0, false, fs)
|
||||
errRead := errors.New("read is not allowed")
|
||||
tr := newTransfer(baseTransfer, nil, r, errRead)
|
||||
_, err = tr.ReadAt(buf, 0)
|
||||
|
|
|
@ -238,6 +238,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
|
|||
}
|
||||
|
||||
initialSize := int64(0)
|
||||
truncatedSize := int64(0) // bytes truncated and not included in quota
|
||||
if !isNewFile {
|
||||
if vfs.IsLocalOrSFTPFs(fs) {
|
||||
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||
|
@ -251,6 +252,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
|
|||
}
|
||||
} else {
|
||||
initialSize = fileSize
|
||||
truncatedSize = initialSize
|
||||
}
|
||||
if maxWriteSize > 0 {
|
||||
maxWriteSize += fileSize
|
||||
|
@ -260,7 +262,7 @@ func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string,
|
|||
vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
|
||||
t := newTransfer(baseTransfer, w, nil, nil)
|
||||
|
||||
return c.getUploadFileData(sizeToRead, t)
|
||||
|
@ -529,7 +531,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
|
|||
}
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
t := newTransfer(baseTransfer, nil, r, nil)
|
||||
|
||||
err = c.sendDownloadFileData(fs, p, stat, t)
|
||||
|
|
|
@ -356,7 +356,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
|
|||
go func() {
|
||||
defer stdin.Close()
|
||||
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
|
||||
common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs)
|
||||
common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
|
||||
|
@ -369,7 +369,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
|
|||
|
||||
go func() {
|
||||
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
|
||||
common.TransferDownload, 0, 0, 0, false, command.fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, command.fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
|
||||
|
@ -383,7 +383,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
|
|||
|
||||
go func() {
|
||||
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
|
||||
common.TransferDownload, 0, 0, 0, false, command.fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, command.fs)
|
||||
transfer := newTransfer(baseTransfer, nil, nil, nil)
|
||||
|
||||
w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)
|
||||
|
|
|
@ -4,23 +4,23 @@ go 1.17
|
|||
|
||||
require (
|
||||
github.com/hashicorp/go-plugin v1.4.3
|
||||
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a
|
||||
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/fatih/color v1.13.0 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/google/go-cmp v0.5.6 // indirect
|
||||
github.com/hashicorp/go-hclog v1.0.0 // indirect
|
||||
github.com/hashicorp/go-hclog v1.1.0 // indirect
|
||||
github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect
|
||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
|
||||
github.com/oklog/run v1.1.0 // indirect
|
||||
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 // indirect
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect
|
||||
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d // indirect
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 // indirect
|
||||
google.golang.org/grpc v1.43.0 // indirect
|
||||
google.golang.org/protobuf v1.27.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
|
|
|
@ -57,8 +57,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
|
||||
github.com/hashicorp/go-hclog v1.0.0 h1:bkKf0BeBXcSYa7f5Fyi9gMuQ8gNsxeiNpZjR6VxNZeo=
|
||||
github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
|
||||
github.com/hashicorp/go-hclog v1.1.0 h1:QsGcniKx5/LuX2eYoeL+Np3UKYPNaN7YKpTh29h8rbw=
|
||||
github.com/hashicorp/go-hclog v1.1.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
|
||||
github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
|
||||
github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
|
||||
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
|
||||
|
@ -85,8 +86,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a h1:JJc19rE0eW2knPa/KIFYvqyu25CwzKltJ5Cw1kK3o4A=
|
||||
github.com/sftpgo/sdk v0.0.0-20220106101837-50e87c59705a/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
|
||||
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea h1:ouwL3x9tXiAXIhdXtJGONd905f1dBLu3HhfFoaTq24k=
|
||||
github.com/sftpgo/sdk v0.0.0-20220115154521-b31d253a0bea/go.mod h1:Bhgac6kiwIziILXLzH4wepT8lQXyhF83poDXqZorN6Q=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
|
@ -110,8 +111,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98 h1:+6WJMRLHlD7X7frgp7TUZ36RnQzSf9wVVTNakEp+nqY=
|
||||
golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs=
|
||||
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
|
@ -132,8 +134,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0=
|
||||
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
|
@ -156,8 +159,9 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA
|
|||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw=
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q=
|
||||
google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
|
||||
|
|
|
@ -149,8 +149,8 @@ func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File
|
|||
}
|
||||
}
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload,
|
||||
0, 0, 0, false, fs)
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
|
||||
return newWebDavFile(baseTransfer, nil, r), nil
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, re
|
|||
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
|
||||
common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
|
||||
|
||||
return newWebDavFile(baseTransfer, w, nil), nil
|
||||
}
|
||||
|
@ -252,6 +252,7 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
|
|||
return nil, c.GetFsError(fs, err)
|
||||
}
|
||||
initialSize := int64(0)
|
||||
truncatedSize := int64(0) // bytes truncated and not included in quota
|
||||
if vfs.IsLocalOrSFTPFs(fs) {
|
||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||
if err == nil {
|
||||
|
@ -264,12 +265,13 @@ func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePat
|
|||
}
|
||||
} else {
|
||||
initialSize = fileSize
|
||||
truncatedSize = fileSize
|
||||
}
|
||||
|
||||
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
|
||||
|
||||
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, false, fs)
|
||||
common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs)
|
||||
|
||||
return newWebDavFile(baseTransfer, w, nil), nil
|
||||
}
|
||||
|
|
|
@ -695,7 +695,7 @@ func TestContentType(t *testing.T) {
|
|||
testFilePath := filepath.Join(user.HomeDir, testFile)
|
||||
ctx := context.Background()
|
||||
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir(), nil)
|
||||
err := os.WriteFile(testFilePath, []byte(""), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
@ -745,7 +745,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|||
}
|
||||
testFilePath := filepath.Join(user.HomeDir, testFile)
|
||||
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferUpload, 0, 0, 0, false, fs)
|
||||
common.TransferUpload, 0, 0, 0, 0, false, fs)
|
||||
davFile := newWebDavFile(baseTransfer, nil, nil)
|
||||
p := make([]byte, 1)
|
||||
_, err := davFile.Read(p)
|
||||
|
@ -763,7 +763,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
_, err = davFile.Read(p)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
@ -771,7 +771,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
err = os.WriteFile(testFilePath, []byte(""), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
f, err := os.Open(testFilePath)
|
||||
|
@ -796,7 +796,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
mockFs := newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, r)
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, mockFs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, mockFs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
|
||||
writeContent := []byte("content\r\n")
|
||||
|
@ -816,7 +816,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
davFile.writer = f
|
||||
err = davFile.Close()
|
||||
|
@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
|
|||
testFilePath := filepath.Join(user.HomeDir, testFile)
|
||||
testFileContents := []byte("content")
|
||||
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferUpload, 0, 0, 0, false, fs)
|
||||
common.TransferUpload, 0, 0, 0, 0, false, fs)
|
||||
davFile := newWebDavFile(baseTransfer, nil, nil)
|
||||
_, err := davFile.Seek(0, io.SeekStart)
|
||||
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
|
||||
|
@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
_, err = davFile.Seek(0, io.SeekCurrent)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
_, err = davFile.Seek(0, io.SeekStart)
|
||||
assert.Error(t, err)
|
||||
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
res, err := davFile.Seek(0, io.SeekStart)
|
||||
assert.NoError(t, err)
|
||||
|
@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
|
|||
assert.Nil(t, err)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
_, err = davFile.Seek(0, io.SeekEnd)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
davFile.reader = f
|
||||
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
|
||||
|
@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
|
|||
assert.Equal(t, int64(5), res)
|
||||
|
||||
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
|
||||
common.TransferDownload, 0, 0, 0, false, fs)
|
||||
common.TransferDownload, 0, 0, 0, 0, false, fs)
|
||||
|
||||
davFile = newWebDavFile(baseTransfer, nil, nil)
|
||||
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
|
||||
|
|
Loading…
Reference in a new issue