Mirror of https://github.com/drakkan/sftpgo.git (synced 2024-11-21 23:20:24 +00:00)

use the new atomic types introduced in Go 1.19

we depend on Go 1.19 anyway

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>

Commit 95e9106902 (parent da03f6c4e3): 22 changed files with 231 additions and 231 deletions
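The diff below applies one mechanical pattern throughout: a plain integer field manipulated through the `sync/atomic` helper functions becomes one of the atomic types added in Go 1.19 (`atomic.Int32`, `atomic.Int64`, `atomic.Bool`), and every call site switches to that type's methods. A minimal before/after sketch of the shape of the change (the `counter` example is hypothetical, not taken from the SFTPGo sources):

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    // Before Go 1.19: a bare int64 plus free functions. Nothing stops a
    // careless `c.n++`, which would race with the atomic accesses.
    type legacyCounter struct {
        n int64
    }

    func (c *legacyCounter) inc() int64 { return atomic.AddInt64(&c.n, 1) }
    func (c *legacyCounter) get() int64 { return atomic.LoadInt64(&c.n) }

    // Go 1.19 style: atomic.Int64 only exposes atomic methods, so
    // non-atomic access is impossible by construction.
    type counter struct {
        n atomic.Int64
    }

    func (c *counter) inc() int64 { return c.n.Add(1) }
    func (c *counter) get() int64 { return c.n.Load() }

    func main() {
        var c counter
        var wg sync.WaitGroup
        for i := 0; i < 100; i++ {
            wg.Add(1)
            go func() { defer wg.Done(); c.inc() }()
        }
        wg.Wait()
        fmt.Println(c.get()) // 100
    }

The typed API also makes the zero value directly usable and removes the need for the alignment tricks that several of the touched structs documented, as the hunks below show.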
@@ -23,13 +23,13 @@ import (
 // clienstMap is a struct containing the map of the connected clients
 type clientsMap struct {
-    totalConnections int32
+    totalConnections atomic.Int32
     mu sync.RWMutex
     clients map[string]int
 }

 func (c *clientsMap) add(source string) {
-    atomic.AddInt32(&c.totalConnections, 1)
+    c.totalConnections.Add(1)

     c.mu.Lock()
     defer c.mu.Unlock()

@@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) {
     defer c.mu.Unlock()

     if val, ok := c.clients[source]; ok {
-        atomic.AddInt32(&c.totalConnections, -1)
+        c.totalConnections.Add(-1)
         c.clients[source]--
         if val > 1 {
             return

@@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) {
 }

 func (c *clientsMap) getTotal() int32 {
-    return atomic.LoadInt32(&c.totalConnections)
+    return c.totalConnections.Load()
 }

 func (c *clientsMap) getTotalFrom(source string) int {
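The clientsMap hunks show a mixed design worth calling out: the per-source map stays behind a sync.RWMutex (maps are not concurrency-safe), while the aggregate total lives in an atomic so getTotal never has to take the lock. A sketch of the same split, with hypothetical names:

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    // connTracker is a hypothetical analogue of clientsMap: the map needs
    // a mutex, but the running total is an atomic.Int32 so readers can
    // poll it without locking.
    type connTracker struct {
        total    atomic.Int32
        mu       sync.RWMutex
        bySource map[string]int
    }

    func (t *connTracker) add(source string) {
        t.total.Add(1)
        t.mu.Lock()
        defer t.mu.Unlock()
        t.bySource[source]++
    }

    func (t *connTracker) remove(source string) {
        t.mu.Lock()
        defer t.mu.Unlock()
        if n, ok := t.bySource[source]; ok {
            t.total.Add(-1)
            if n <= 1 {
                delete(t.bySource, source)
                return
            }
            t.bySource[source]--
        }
    }

    // getTotal is lock-free: it never blocks behind add/remove.
    func (t *connTracker) getTotal() int32 { return t.total.Load() }

    func main() {
        t := &connTracker{bySource: make(map[string]int)}
        t.add("10.0.0.1")
        t.add("10.0.0.1")
        t.remove("10.0.0.1")
        fmt.Println(t.getTotal()) // 1
    }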
@@ -704,16 +704,17 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
 type SSHConnection struct {
     id string
     conn net.Conn
-    lastActivity int64
+    lastActivity atomic.Int64
 }

 // NewSSHConnection returns a new SSHConnection
 func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
-    return &SSHConnection{
-        id: id,
-        conn: conn,
-        lastActivity: time.Now().UnixNano(),
+    c := &SSHConnection{
+        id: id,
+        conn: conn,
     }
+    c.lastActivity.Store(time.Now().UnixNano())
+    return c
 }

 // GetID returns the ID for this SSHConnection

@@ -723,12 +724,12 @@ func (c *SSHConnection) GetID() string {

 // UpdateLastActivity updates last activity for this connection
 func (c *SSHConnection) UpdateLastActivity() {
-    atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
+    c.lastActivity.Store(time.Now().UnixNano())
 }

 // GetLastActivity returns the last connection activity
 func (c *SSHConnection) GetLastActivity() time.Time {
-    return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
+    return time.Unix(0, c.lastActivity.Load())
 }

 // Close closes the underlying network connection

@@ -741,7 +742,7 @@ type ActiveConnections struct {
     // clients contains both authenticated and estabilished connections and the ones waiting
     // for authentication
     clients clientsMap
-    transfersCheckStatus int32
+    transfersCheckStatus atomic.Bool
     sync.RWMutex
     connections []ActiveConnection
     sshConnections []*SSHConnection

@@ -953,12 +954,12 @@ func (conns *ActiveConnections) checkIdles() {
 }

 func (conns *ActiveConnections) checkTransfers() {
-    if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
+    if conns.transfersCheckStatus.Load() {
         logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
         return
     }
-    atomic.StoreInt32(&conns.transfersCheckStatus, 1)
-    defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
+    conns.transfersCheckStatus.Store(true)
+    defer conns.transfersCheckStatus.Store(false)

     conns.RLock()
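Two details in these hunks deserve a note. First, the new atomic types cannot be initialized inside a composite literal (their value field is unexported), which is why NewSSHConnection now builds the struct and then calls Store; since the zero value of atomic.Int64/atomic.Bool is already usable, the explicit Store is only needed when the initial value is non-zero. Second, the checkTransfers guard performs a separate Load and Store; a CompareAndSwap would make the test-and-set a single atomic step. A sketch of that variant (hypothetical, not what this commit does):

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    var checkRunning atomic.Bool

    // runCheck uses CompareAndSwap so that "is a check already running?"
    // and "mark the check as running" happen atomically, closing the small
    // window left by a separate Load + Store.
    func runCheck() bool {
        if !checkRunning.CompareAndSwap(false, true) {
            return false // another check is in progress
        }
        defer checkRunning.Store(false)
        // ... do the actual check here ...
        return true
    }

    func main() {
        fmt.Println(runCheck()) // true
    }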
@@ -24,7 +24,6 @@ import (
     "path/filepath"
     "runtime"
     "strings"
-    "sync/atomic"
     "testing"
     "time"

@@ -498,14 +497,14 @@ func TestIdleConnections(t *testing.T) {
         },
     }
     c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
-    c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
+    c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
     fakeConn := &fakeConnection{
         BaseConnection: c,
     }
     // both ssh connections are expired but they should get removed only
     // if there is no associated connection
-    sshConn1.lastActivity = c.lastActivity
-    sshConn2.lastActivity = c.lastActivity
+    sshConn1.lastActivity.Store(c.lastActivity.Load())
+    sshConn2.lastActivity.Store(c.lastActivity.Load())
     Connections.AddSSHConnection(sshConn1)
     err = Connections.Add(fakeConn)
     assert.NoError(t, err)

@@ -520,7 +519,7 @@ func TestIdleConnections(t *testing.T) {
     assert.Equal(t, Connections.GetActiveSessions(username), 2)

     cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
-    cFTP.lastActivity = time.Now().UnixNano()
+    cFTP.lastActivity.Store(time.Now().UnixNano())
     fakeConn = &fakeConnection{
         BaseConnection: cFTP,
     }

@@ -541,9 +540,9 @@ func TestIdleConnections(t *testing.T) {
     }, 1*time.Second, 200*time.Millisecond)
     stopEventScheduler()
     assert.Len(t, Connections.GetStats(), 2)
-    c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
-    cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
-    sshConn2.lastActivity = c.lastActivity
+    c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
+    cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
+    sshConn2.lastActivity.Store(c.lastActivity.Load())
     startPeriodicChecks(100 * time.Millisecond)
     assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond)
     assert.Eventually(t, func() bool {

@@ -646,9 +645,9 @@ func TestConnectionStatus(t *testing.T) {
         BaseConnection: c1,
     }
     t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-    t1.BytesReceived = 123
+    t1.BytesReceived.Store(123)
     t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-    t2.BytesSent = 456
+    t2.BytesSent.Store(456)
     c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
     fakeConn2 := &fakeConnection{
         BaseConnection: c2,

@@ -698,7 +697,7 @@ func TestConnectionStatus(t *testing.T) {

     err = fakeConn3.SignalTransfersAbort()
     assert.NoError(t, err)
-    assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
+    assert.True(t, t3.AbortTransfer.Load())
     err = t3.Close()
     assert.NoError(t, err)
     err = fakeConn3.SignalTransfersAbort()
@@ -38,12 +38,12 @@ import (
 type BaseConnection struct {
     // last activity for this connection.
     // Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
-    lastActivity int64
+    lastActivity atomic.Int64
     uploadDone atomic.Bool
     downloadDone atomic.Bool
     // unique ID for a transfer.
    // This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
-    transferID int64
+    transferID atomic.Int64
     // Unique identifier for the connection
     ID string
     // user associated with this connection if any

@@ -64,16 +64,18 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov
         connID = fmt.Sprintf("%s_%s", protocol, id)
     }
     user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
-    return &BaseConnection{
-        ID: connID,
-        User: user,
-        startTime: time.Now(),
-        protocol: protocol,
-        localAddr: localAddr,
-        remoteAddr: remoteAddr,
-        lastActivity: time.Now().UnixNano(),
-        transferID: 0,
+    c := &BaseConnection{
+        ID: connID,
+        User: user,
+        startTime: time.Now(),
+        protocol: protocol,
+        localAddr: localAddr,
+        remoteAddr: remoteAddr,
     }
+    c.transferID.Store(0)
+    c.lastActivity.Store(time.Now().UnixNano())
+
+    return c
 }

 // Log outputs a log entry to the configured logger

@@ -83,7 +85,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {

 // GetTransferID returns an unique transfer ID for this connection
 func (c *BaseConnection) GetTransferID() int64 {
-    return atomic.AddInt64(&c.transferID, 1)
+    return c.transferID.Add(1)
 }

 // GetID returns the connection ID

@@ -126,12 +128,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time {

 // UpdateLastActivity updates last activity for this connection
 func (c *BaseConnection) UpdateLastActivity() {
-    atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
+    c.lastActivity.Store(time.Now().UnixNano())
 }

 // GetLastActivity returns the last connection activity
 func (c *BaseConnection) GetLastActivity() time.Time {
-    return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
+    return time.Unix(0, c.lastActivity.Load())
 }

 // CloseFS closes the underlying fs
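The alignment comments kept as context in BaseConnection are a relic of the old representation: with a bare int64, the 64-bit atomic functions require the operand to be 64-bit aligned, which on 32-bit platforms was only guaranteed for the first word of an allocated struct, hence the careful field ordering. atomic.Int64 carries its own alignment guarantee, so with the new type the field can sit anywhere. A small illustration (hypothetical types; the align64 detail is my reading of the Go 1.19 runtime, not something stated in this diff):

    package main

    import (
        "fmt"
        "sync/atomic"
        "unsafe"
    )

    // With a bare int64 after a bool, the int64 could land on a 4-byte
    // boundary on 32-bit targets, and atomic.AddInt64 on it would fault.
    type fragile struct {
        flag bool
        n    int64 // placement matters for the legacy atomic functions
    }

    // atomic.Int64 embeds a 64-bit alignment marker, so placement in the
    // struct is irrelevant even on 32-bit platforms.
    type robust struct {
        flag bool
        n    atomic.Int64
    }

    func main() {
        var r robust
        r.n.Add(42)
        fmt.Println(r.n.Load(), unsafe.Offsetof(r.n)) // 42, 8-byte aligned offset
    }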
@@ -82,7 +82,7 @@ func HandleCertificateEvent(params EventParams) {
 // eventRulesContainer stores event rules by trigger
 type eventRulesContainer struct {
     sync.RWMutex
-    lastLoad int64
+    lastLoad atomic.Int64
     FsEvents []dataprovider.EventRule
     ProviderEvents []dataprovider.EventRule
     Schedules []dataprovider.EventRule

@@ -101,11 +101,11 @@ func (r *eventRulesContainer) removeAsyncTask() {
 }

 func (r *eventRulesContainer) getLastLoadTime() int64 {
-    return atomic.LoadInt64(&r.lastLoad)
+    return r.lastLoad.Load()
 }

 func (r *eventRulesContainer) setLastLoadTime(modTime int64) {
-    atomic.StoreInt64(&r.lastLoad, modTime)
+    r.lastLoad.Store(modTime)
 }

 // RemoveRule deletes the rule with the specified name
@@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
 }

 type sourceRateLimiter struct {
-    lastActivity int64
+    lastActivity *atomic.Int64
     bucket *rate.Limiter
 }

 func (s *sourceRateLimiter) updateLastActivity() {
-    atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
+    s.lastActivity.Store(time.Now().UnixNano())
 }

 func (s *sourceRateLimiter) getLastActivity() int64 {
-    return atomic.LoadInt64(&s.lastActivity)
+    return s.lastActivity.Load()
 }

 type sourceBuckets struct {

@@ -224,7 +224,8 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese
     b.cleanup()

     src := sourceRateLimiter{
-        bucket: r,
+        lastActivity: new(atomic.Int64),
+        bucket:       r,
     }
     src.updateLastActivity()
     b.buckets[source] = src
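The rate limiter is the one place where the commit uses a pointer, *atomic.Int64, rather than the value type. The likely reason is the last line of the second hunk: sourceRateLimiter is stored into the buckets map by value, and the atomic types must not be copied after first use (they embed a noCopy marker that go vet's copylocks check flags). Holding a pointer lets every copy of the struct share one counter. A sketch of the trap and the fix, with hypothetical names:

    package main

    import (
        "fmt"
        "sync/atomic"
        "time"
    )

    // entry values are copied into a map, so the counter is held behind a
    // pointer: all copies of an entry share the same atomic.Int64 and
    // go vet raises no copylocks complaint. Embedding atomic.Int64 by
    // value here would copy the counter on every map insert.
    type entry struct {
        lastActivity *atomic.Int64
    }

    func main() {
        buckets := make(map[string]entry)
        e := entry{lastActivity: new(atomic.Int64)}
        e.lastActivity.Store(time.Now().UnixNano())
        buckets["10.0.0.1"] = e // copies the struct, shares the counter

        e.lastActivity.Store(42) // visible through the map copy too
        fmt.Println(buckets["10.0.0.1"].lastActivity.Load()) // 42
    }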
@@ -35,8 +35,8 @@ var (
 // BaseTransfer contains protocols common transfer details for an upload or a download.
 type BaseTransfer struct { //nolint:maligned
     ID int64
-    BytesSent int64
-    BytesReceived int64
+    BytesSent atomic.Int64
+    BytesReceived atomic.Int64
     Fs vfs.Fs
     File vfs.File
     Connection *BaseConnection

@@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned
     truncatedSize int64
     isNewFile bool
     transferType int
-    AbortTransfer int32
+    AbortTransfer atomic.Bool
     aTime time.Time
     mTime time.Time
     transferQuota dataprovider.TransferQuota

@@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
         InitialSize: initialSize,
         isNewFile: isNewFile,
         requestPath: requestPath,
-        BytesSent: 0,
-        BytesReceived: 0,
         MaxWriteSize: maxWriteSize,
-        AbortTransfer: 0,
         truncatedSize: truncatedSize,
         transferQuota: transferQuota,
         Fs: fs,
     }
+    t.AbortTransfer.Store(false)
+    t.BytesSent.Store(0)
+    t.BytesReceived.Store(0)

     conn.AddTransfer(t)
     return t

@@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int {
 // GetSize returns the transferred size
 func (t *BaseTransfer) GetSize() int64 {
     if t.transferType == TransferDownload {
-        return atomic.LoadInt64(&t.BytesSent)
+        return t.BytesSent.Load()
     }
-    return atomic.LoadInt64(&t.BytesReceived)
+    return t.BytesReceived.Load()
 }

 // GetDownloadedSize returns the transferred size
 func (t *BaseTransfer) GetDownloadedSize() int64 {
-    return atomic.LoadInt64(&t.BytesSent)
+    return t.BytesSent.Load()
 }

 // GetUploadedSize returns the transferred size
 func (t *BaseTransfer) GetUploadedSize() int64 {
-    return atomic.LoadInt64(&t.BytesReceived)
+    return t.BytesReceived.Load()
 }

 // GetStartTime returns the start time

@@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) {
     t.Lock()
     t.errAbort = err
     t.Unlock()
-    atomic.StoreInt32(&(t.AbortTransfer), 1)
+    t.AbortTransfer.Store(true)
 }

 // GetTruncatedSize returns the truncated sized if this is an upload overwriting

@@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error {
         return nil
     }
     if t.transferQuota.AllowedTotalSize > 0 {
-        if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
+        if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
             return t.Connection.GetReadQuotaExceededError()
         }
     } else if t.transferQuota.AllowedDLSize > 0 {
-        if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize {
+        if t.BytesSent.Load() > t.transferQuota.AllowedDLSize {
             return t.Connection.GetReadQuotaExceededError()
         }
     }

@@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error {

 // CheckWrite returns an error if write if not allowed
 func (t *BaseTransfer) CheckWrite() error {
-    if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize {
+    if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize {
         return t.Connection.GetQuotaExceededError()
     }
     if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
         return nil
     }
     if t.transferQuota.AllowedTotalSize > 0 {
-        if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
+        if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
             return t.Connection.GetQuotaExceededError()
         }
     } else if t.transferQuota.AllowedULSize > 0 {
-        if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize {
+        if t.BytesReceived.Load() > t.transferQuota.AllowedULSize {
             return t.Connection.GetQuotaExceededError()
         }
     }

@@ -261,14 +261,14 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
     if t.MaxWriteSize > 0 {
         sizeDiff := initialSize - size
         t.MaxWriteSize += sizeDiff
-        metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
+        metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
             t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
         if t.transferQuota.HasSizeLimits() {
             go func(ulSize, dlSize int64, user dataprovider.User) {
                 dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-            }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
+            }(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User)
         }
-        atomic.StoreInt64(&t.BytesReceived, 0)
+        t.BytesReceived.Store(0)
     }
     t.Unlock()
 }

@@ -276,7 +276,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
         fsPath, size, t.MaxWriteSize, t.InitialSize, err)
     return initialSize, err
 }
-    if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 {
+    if size == 0 && t.BytesSent.Load() == 0 {
         // for cloud providers the file is always truncated to zero, we don't support append/resume for uploads
         // for buffered SFTP we can have buffered bytes so we returns an error
         if !vfs.IsBufferedSFTPFs(t.Fs) {

@@ -302,8 +302,8 @@ func (t *BaseTransfer) TransferError(err error) {
     }
     elapsed := time.Since(t.start).Nanoseconds() / 1000000
     t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
-        "bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent),
-        atomic.LoadInt64(&t.BytesReceived), elapsed)
+        "bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
+        t.BytesReceived.Load(), elapsed)
 }

 func (t *BaseTransfer) getUploadFileSize() (int64, error) {

@@ -333,7 +333,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
     t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v",
         t.effectiveFsPath, err)
     // the file is outside the home dir so don't update the quota
-    atomic.StoreInt64(&t.BytesReceived, 0)
+    t.BytesReceived.Store(0)
     t.MinWriteOffset = 0
     return 1
 }

@@ -351,18 +351,18 @@ func (t *BaseTransfer) Close() error {
     if t.isNewFile {
         numFiles = 1
     }
-    metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
+    metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
         t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
     if t.transferQuota.HasSizeLimits() {
-        dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
-            atomic.LoadInt64(&t.BytesSent), false)
+        dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
+            t.BytesSent.Load(), false)
     }
     if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
         // if quota is exceeded we try to remove the partial file for uploads to local filesystem
         err = t.Fs.Remove(t.File.Name(), false)
         if err == nil {
             numFiles--
-            atomic.StoreInt64(&t.BytesReceived, 0)
+            t.BytesReceived.Store(0)
             t.MinWriteOffset = 0
         }
         t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",

@@ -380,7 +380,7 @@ func (t *BaseTransfer) Close() error {
         t.ErrTransfer, t.effectiveFsPath, err)
     if err == nil {
         numFiles--
-        atomic.StoreInt64(&t.BytesReceived, 0)
+        t.BytesReceived.Store(0)
         t.MinWriteOffset = 0
     }
 }

@@ -388,12 +388,12 @@ func (t *BaseTransfer) Close() error {
     elapsed := time.Since(t.start).Nanoseconds() / 1000000
     var uploadFileSize int64
     if t.transferType == TransferDownload {
-        logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username,
+        logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
            t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
         ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
-            atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
+            t.BytesSent.Load(), t.ErrTransfer)
     } else {
-        uploadFileSize = atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
+        uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
         if statSize, errStat := t.getUploadFileSize(); errStat == nil {
             uploadFileSize = statSize
         }

@@ -401,7 +401,7 @@ func (t *BaseTransfer) Close() error {
         numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize)
         t.updateQuota(numFiles, uploadFileSize)
         t.updateTimes()
-        logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username,
+        logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
            t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
     }
     if t.ErrTransfer != nil {

@@ -428,11 +428,11 @@ func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) {
         }
         return
     }
-    if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && atomic.LoadInt64(&t.BytesSent) > 0 {
+    if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
        if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
            t.Connection.downloadDone.Store(true)
            ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
-                "", "", atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
+                "", "", t.BytesSent.Load(), t.ErrTransfer)
        }
     }
 }

@@ -449,7 +449,7 @@ func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int
     if err == nil {
         numFiles--
         fileSize = 0
-        atomic.StoreInt64(&t.BytesReceived, 0)
+        t.BytesReceived.Store(0)
         t.MinWriteOffset = 0
     } else {
         t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)

@@ -494,10 +494,10 @@ func (t *BaseTransfer) HandleThrottle() {
     var trasferredBytes int64
     if t.transferType == TransferDownload {
         wantedBandwidth = t.Connection.User.DownloadBandwidth
-        trasferredBytes = atomic.LoadInt64(&t.BytesSent)
+        trasferredBytes = t.BytesSent.Load()
     } else {
         wantedBandwidth = t.Connection.User.UploadBandwidth
-        trasferredBytes = atomic.LoadInt64(&t.BytesReceived)
+        trasferredBytes = t.BytesReceived.Load()
     }
     if wantedBandwidth > 0 {
         // real and wanted elapsed as milliseconds, bytes as kilobytes
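HandleThrottle's arithmetic is easier to follow outside the diff: with the bandwidth limit expressed in KB/s, the expected time to have moved N bytes is 1000 * (N / 1024) / limit milliseconds, and the transfer sleeps for the difference whenever it is running ahead of that schedule. A self-contained sketch of the same computation (simplified; the real code reads the counters atomically exactly as the hunks above do):

    package main

    import (
        "fmt"
        "sync/atomic"
        "time"
    )

    // throttle sleeps just long enough that `transferred` bytes since
    // `start` do not exceed `limitKBs` kilobytes per second. Elapsed and
    // wanted elapsed are in milliseconds, bytes are converted to kilobytes,
    // mirroring the HandleThrottle arithmetic.
    func throttle(transferred *atomic.Int64, start time.Time, limitKBs int64) {
        if limitKBs <= 0 {
            return
        }
        elapsed := time.Since(start).Milliseconds()
        wantedElapsed := 1000 * (transferred.Load() / 1024) / limitKBs
        if wantedElapsed > elapsed {
            time.Sleep(time.Duration(wantedElapsed-elapsed) * time.Millisecond)
        }
    }

    func main() {
        var sent atomic.Int64
        start := time.Now()
        sent.Add(256 * 1024)        // pretend we already moved 256 KB
        throttle(&sent, start, 512) // cap at 512 KB/s -> sleeps ~500 ms
        fmt.Printf("elapsed: %v\n", time.Since(start).Round(100*time.Millisecond))
    }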
@@ -33,11 +33,11 @@ import (
 func TestTransferUpdateQuota(t *testing.T) {
     conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
     transfer := BaseTransfer{
-        Connection: conn,
-        transferType: TransferUpload,
-        BytesReceived: 123,
-        Fs: vfs.NewOsFs("", os.TempDir(), ""),
+        Connection: conn,
+        transferType: TransferUpload,
+        Fs: vfs.NewOsFs("", os.TempDir(), ""),
     }
+    transfer.BytesReceived.Store(123)
     errFake := errors.New("fake error")
     transfer.TransferError(errFake)
     assert.False(t, transfer.updateQuota(1, 0))

@@ -56,7 +56,7 @@ func TestTransferUpdateQuota(t *testing.T) {
         QuotaSize: -1,
     })
     transfer.ErrTransfer = nil
-    transfer.BytesReceived = 1
+    transfer.BytesReceived.Store(1)
     transfer.requestPath = "/vdir/file"
     assert.True(t, transfer.updateQuota(1, 0))
     err = transfer.Close()

@@ -80,7 +80,7 @@ func TestTransferThrottling(t *testing.T) {
     wantedDownloadElapsed -= wantedDownloadElapsed / 10
     conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
     transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-    transfer.BytesReceived = testFileSize
+    transfer.BytesReceived.Store(testFileSize)
     transfer.Connection.UpdateLastActivity()
     startTime := transfer.Connection.GetLastActivity()
     transfer.HandleThrottle()

@@ -90,7 +90,7 @@ func TestTransferThrottling(t *testing.T) {
     assert.NoError(t, err)

     transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-    transfer.BytesSent = testFileSize
+    transfer.BytesSent.Store(testFileSize)
     transfer.Connection.UpdateLastActivity()
     startTime = transfer.Connection.GetLastActivity()

@@ -226,7 +226,7 @@ func TestTransferErrors(t *testing.T) {
     assert.Equal(t, testFile, transfer.GetFsPath())
     transfer.SetCancelFn(cancelFn)
     errFake := errors.New("err fake")
-    transfer.BytesReceived = 9
+    transfer.BytesReceived.Store(9)
     transfer.TransferError(ErrQuotaExceeded)
     assert.True(t, isCancelled)
     transfer.TransferError(errFake)

@@ -249,7 +249,7 @@ func TestTransferErrors(t *testing.T) {
     fsPath := filepath.Join(os.TempDir(), "test_file")
     transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
         fs, dataprovider.TransferQuota{})
-    transfer.BytesReceived = 9
+    transfer.BytesReceived.Store(9)
     transfer.TransferError(errFake)
     assert.Error(t, transfer.ErrTransfer, errFake.Error())
     // the file is closed from the embedding struct before to call close

@@ -269,7 +269,7 @@ func TestTransferErrors(t *testing.T) {
     }
     transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
         fs, dataprovider.TransferQuota{})
-    transfer.BytesReceived = 9
+    transfer.BytesReceived.Store(9)
     // the file is closed from the embedding struct before to call close
     err = file.Close()
     assert.NoError(t, err)

@@ -310,11 +310,11 @@ func TestRemovePartialCryptoFile(t *testing.T) {
 func TestFTPMode(t *testing.T) {
     conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
     transfer := BaseTransfer{
-        Connection: conn,
-        transferType: TransferUpload,
-        BytesReceived: 123,
-        Fs: vfs.NewOsFs("", os.TempDir(), ""),
+        Connection: conn,
+        transferType: TransferUpload,
+        Fs: vfs.NewOsFs("", os.TempDir(), ""),
     }
+    transfer.BytesReceived.Store(123)
     assert.Empty(t, transfer.ftpMode)
     transfer.SetFtpMode("active")
     assert.Equal(t, "active", transfer.ftpMode)

@@ -399,14 +399,14 @@ func TestTransferQuota(t *testing.T) {
     transfer.transferQuota = dataprovider.TransferQuota{
         AllowedTotalSize: 10,
     }
-    transfer.BytesReceived = 5
-    transfer.BytesSent = 4
+    transfer.BytesReceived.Store(5)
+    transfer.BytesSent.Store(4)
     err = transfer.CheckRead()
     assert.NoError(t, err)
     err = transfer.CheckWrite()
     assert.NoError(t, err)

-    transfer.BytesSent = 6
+    transfer.BytesSent.Store(6)
     err = transfer.CheckRead()
     if assert.Error(t, err) {
         assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())

@@ -428,7 +428,7 @@ func TestTransferQuota(t *testing.T) {
     err = transfer.CheckWrite()
     assert.NoError(t, err)

-    transfer.BytesReceived = 11
+    transfer.BytesReceived.Store(11)
     err = transfer.CheckRead()
     if assert.Error(t, err) {
         assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())

@@ -442,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {

     conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
     transfer := BaseTransfer{
-        Connection: conn,
-        transferType: TransferUpload,
-        BytesReceived: 123,
-        Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
+        Connection: conn,
+        transferType: TransferUpload,
+        Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
     }
+    transfer.BytesReceived.Store(123)

     fileName := filepath.Join(os.TempDir(), "_temp")
     err := os.WriteFile(fileName, []byte(`data`), 0644)

@@ -459,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
     Config.TempPath = filepath.Clean(os.TempDir())
     res = transfer.checkUploadOutsideHomeDir(nil)
     assert.Equal(t, 0, res)
-    assert.Greater(t, transfer.BytesReceived, int64(0))
+    assert.Greater(t, transfer.BytesReceived.Load(), int64(0))
     res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
     assert.Equal(t, 1, res)
-    assert.Equal(t, int64(0), transfer.BytesReceived)
+    assert.Equal(t, int64(0), transfer.BytesReceived.Load())
     assert.NoFileExists(t, fileName)

     Config.TempPath = oldTempPath
@@ -21,7 +21,6 @@ import (
     "path/filepath"
     "strconv"
     "strings"
-    "sync/atomic"
     "testing"
     "time"

@@ -96,7 +95,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
     }
     transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
         "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
-    transfer1.BytesReceived = 150
+    transfer1.BytesReceived.Store(150)
     err = Connections.Add(fakeConn1)
     assert.NoError(t, err)
     // the transferschecker will do nothing if there is only one ongoing transfer

@@ -110,8 +109,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
     }
     transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
         "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
-    transfer1.BytesReceived = 50
-    transfer2.BytesReceived = 60
+    transfer1.BytesReceived.Store(50)
+    transfer2.BytesReceived.Store(60)
     err = Connections.Add(fakeConn2)
     assert.NoError(t, err)

@@ -122,7 +121,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
     }
     transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
         "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
-    transfer3.BytesReceived = 60 // this value will be ignored, this is a download
+    transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download
     err = Connections.Add(fakeConn3)
     assert.NoError(t, err)

@@ -132,20 +131,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
     assert.Nil(t, transfer2.errAbort)
     assert.Nil(t, transfer3.errAbort)

-    transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
+    transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota
     Connections.checkTransfers()
     assert.Nil(t, transfer1.errAbort)
     assert.Nil(t, transfer2.errAbort)
     assert.Nil(t, transfer3.errAbort)
-    transfer1.BytesReceived = 120
+    transfer1.BytesReceived.Store(120)
     // we are now overquota
     // if another check is in progress nothing is done
-    atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
+    Connections.transfersCheckStatus.Store(true)
     Connections.checkTransfers()
     assert.Nil(t, transfer1.errAbort)
     assert.Nil(t, transfer2.errAbort)
     assert.Nil(t, transfer3.errAbort)
-    atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
+    Connections.transfersCheckStatus.Store(false)

     Connections.checkTransfers()
     assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)

@@ -172,8 +171,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
     assert.Nil(t, transfer2.errAbort)
     assert.Nil(t, transfer3.errAbort)
     // now check a public folder
-    transfer1.BytesReceived = 0
-    transfer2.BytesReceived = 0
+    transfer1.BytesReceived.Store(0)
+    transfer2.BytesReceived.Store(0)
     connID4 := xid.New().String()
     fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
     assert.NoError(t, err)

@@ -197,12 +196,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {

     err = Connections.Add(fakeConn5)
     assert.NoError(t, err)
-    transfer4.BytesReceived = 50
-    transfer5.BytesReceived = 40
+    transfer4.BytesReceived.Store(50)
+    transfer5.BytesReceived.Store(40)
     Connections.checkTransfers()
     assert.Nil(t, transfer4.errAbort)
     assert.Nil(t, transfer5.errAbort)
-    transfer5.BytesReceived = 60
+    transfer5.BytesReceived.Store(60)
     Connections.checkTransfers()
     assert.Nil(t, transfer1.errAbort)
     assert.Nil(t, transfer2.errAbort)

@@ -286,7 +285,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
     }
     transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
         "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
-    transfer1.BytesReceived = 150
+    transfer1.BytesReceived.Store(150)
     err = Connections.Add(fakeConn1)
     assert.NoError(t, err)
     // the transferschecker will do nothing if there is only one ongoing transfer

@@ -300,26 +299,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
     }
     transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
         "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
-    transfer2.BytesReceived = 150
+    transfer2.BytesReceived.Store(150)
     err = Connections.Add(fakeConn2)
     assert.NoError(t, err)
     Connections.checkTransfers()
     assert.Nil(t, transfer1.errAbort)
     assert.Nil(t, transfer2.errAbort)
     // now test overquota
-    transfer1.BytesReceived = 1024*1024 + 1
-    transfer2.BytesReceived = 0
+    transfer1.BytesReceived.Store(1024*1024 + 1)
+    transfer2.BytesReceived.Store(0)
     Connections.checkTransfers()
     assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
     assert.Nil(t, transfer2.errAbort)
     transfer1.errAbort = nil
-    transfer1.BytesReceived = 1024*1024 + 1
-    transfer2.BytesReceived = 1024
+    transfer1.BytesReceived.Store(1024*1024 + 1)
+    transfer2.BytesReceived.Store(1024)
     Connections.checkTransfers()
     assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
     assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
-    transfer1.BytesReceived = 0
-    transfer2.BytesReceived = 0
+    transfer1.BytesReceived.Store(0)
+    transfer2.BytesReceived.Store(0)
     transfer1.errAbort = nil
     transfer2.errAbort = nil

@@ -337,7 +336,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
     }
     transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
         "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
-    transfer3.BytesSent = 150
+    transfer3.BytesSent.Store(150)
     err = Connections.Add(fakeConn3)
     assert.NoError(t, err)

@@ -348,15 +347,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
     }
     transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
         "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
-    transfer4.BytesSent = 150
+    transfer4.BytesSent.Store(150)
     err = Connections.Add(fakeConn4)
     assert.NoError(t, err)
     Connections.checkTransfers()
     assert.Nil(t, transfer3.errAbort)
     assert.Nil(t, transfer4.errAbort)

-    transfer3.BytesSent = 512 * 1024
-    transfer4.BytesSent = 512*1024 + 1
+    transfer3.BytesSent.Store(512 * 1024)
+    transfer4.BytesSent.Store(512*1024 + 1)
     Connections.checkTransfers()
     if assert.Error(t, transfer3.errAbort) {
         assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())
@@ -155,7 +155,7 @@ var (
     ErrInvalidCredentials = errors.New("invalid credentials")
     // ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP
     ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP")
-    isAdminCreated = int32(0)
+    isAdminCreated atomic.Bool
     validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
     config Config
     provider Provider

@@ -844,7 +844,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
     if err != nil {
         return err
     }
-    atomic.StoreInt32(&isAdminCreated, int32(len(admins)))
+    isAdminCreated.Store(len(admins) > 0)
     delayedQuotaUpdater.start()
     return startScheduler()
 }

@@ -1722,7 +1722,7 @@ func UpdateTaskTimestamp(name string) error {
 // HasAdmin returns true if the first admin has been created
 // and so SFTPGo is ready to be used
 func HasAdmin() bool {
-    return atomic.LoadInt32(&isAdminCreated) > 0
+    return isAdminCreated.Load()
 }

 // AddAdmin adds a new SFTPGo admin

@@ -1734,7 +1734,7 @@ func AddAdmin(admin *Admin, executor, ipAddress string) error {
     admin.Username = config.convertName(admin.Username)
     err := provider.addAdmin(admin)
     if err == nil {
-        atomic.StoreInt32(&isAdminCreated, 1)
+        isAdminCreated.Store(true)
        executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, admin)
     }
     return err
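The isAdminCreated change is one of the few conversions that is more than mechanical: the old code stored int32(len(admins)) and treated any value > 0 as true, while the new code stores the boolean len(admins) > 0 directly, so the intent is visible in the type. A sketch of the same one-way latch (hypothetical names):

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    // ready is a one-way latch: it starts false, is set true once the
    // first admin exists, and is only read afterwards. atomic.Bool states
    // that contract directly, where an int32 holding 0/1 (or a length)
    // left it implicit.
    var ready atomic.Bool

    func initialize(adminCount int) { ready.Store(adminCount > 0) }

    func hasAdmin() bool { return ready.Load() }

    func main() {
        initialize(0)
        fmt.Println(hasAdmin()) // false
        initialize(3)
        fmt.Println(hasAdmin()) // true
    }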
@@ -28,11 +28,11 @@ import (

 var (
     scheduler *cron.Cron
-    lastUserCacheUpdate int64
+    lastUserCacheUpdate atomic.Int64
     // used for bolt and memory providers, so we avoid iterating all users/rules
     // to find recently modified ones
-    lastUserUpdate int64
-    lastRuleUpdate int64
+    lastUserUpdate atomic.Int64
+    lastRuleUpdate atomic.Int64
 )

 func stopScheduler() {

@@ -62,7 +62,7 @@ func startScheduler() error {
 }

 func addScheduledCacheUpdates() error {
-    lastUserCacheUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
+    lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
     _, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
     if err != nil {
         return fmt.Errorf("unable to schedule cache updates: %w", err)

@@ -79,9 +79,9 @@ func checkDataprovider() {
 }

 func checkCacheUpdates() {
-    providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
+    providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
     checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
-    users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate)
+    users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate.Load())
     if err != nil {
         providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
         return

@@ -102,22 +102,22 @@ func checkCacheUpdates() {
         cachedPasswords.Remove(user.Username)
     }

-    lastUserCacheUpdate = checkTime
-    providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
+    lastUserCacheUpdate.Store(checkTime)
+    providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
 }

 func setLastUserUpdate() {
-    atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
+    lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
 }

 func getLastUserUpdate() int64 {
-    return atomic.LoadInt64(&lastUserUpdate)
+    return lastUserUpdate.Load()
 }

 func setLastRuleUpdate() {
-    atomic.StoreInt64(&lastRuleUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
+    lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
 }

 func getLastRuleUpdate() int64 {
-    return atomic.LoadInt64(&lastRuleUpdate)
+    return lastRuleUpdate.Load()
 }
@@ -17,7 +17,6 @@ package ftpd
 import (
     "errors"
     "io"
-    "sync/atomic"

     "github.com/eikenb/pipeat"

@@ -61,7 +60,7 @@ func (t *transfer) Read(p []byte) (n int, err error) {
     t.Connection.UpdateLastActivity()

     n, err = t.reader.Read(p)
-    atomic.AddInt64(&t.BytesSent, int64(n))
+    t.BytesSent.Add(int64(n))

     if err == nil {
         err = t.CheckRead()

@@ -79,7 +78,7 @@ func (t *transfer) Write(p []byte) (n int, err error) {
     t.Connection.UpdateLastActivity()

     n, err = t.writer.Write(p)
-    atomic.AddInt64(&t.BytesReceived, int64(n))
+    t.BytesReceived.Add(int64(n))

     if err == nil {
         err = t.CheckWrite()
@@ -16,7 +16,6 @@ package httpd

 import (
     "io"
-    "sync/atomic"

     "github.com/eikenb/pipeat"

@@ -52,7 +51,7 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,

 // Read reads the contents to downloads.
 func (f *httpdFile) Read(p []byte) (n int, err error) {
-    if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+    if f.AbortTransfer.Load() {
         err := f.GetAbortError()
         f.TransferError(err)
         return 0, err

@@ -61,7 +60,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
     f.Connection.UpdateLastActivity()

     n, err = f.reader.Read(p)
-    atomic.AddInt64(&f.BytesSent, int64(n))
+    f.BytesSent.Add(int64(n))

     if err == nil {
         err = f.CheckRead()

@@ -76,7 +75,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {

 // Write writes the contents to upload
 func (f *httpdFile) Write(p []byte) (n int, err error) {
-    if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+    if f.AbortTransfer.Load() {
         err := f.GetAbortError()
         f.TransferError(err)
         return 0, err

@@ -85,7 +84,7 @@ func (f *httpdFile) Write(p []byte) (n int, err error) {
     f.Connection.UpdateLastActivity()

     n, err = f.writer.Write(p)
-    atomic.AddInt64(&f.BytesReceived, int64(n))
+    f.BytesReceived.Add(int64(n))

     if err == nil {
         err = f.CheckWrite()
@@ -238,24 +238,24 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request

 func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader {
     t := &throttledReader{
-        bytesRead: 0,
-        id: conn.GetTransferID(),
-        limit: limit,
-        r: r,
-        abortTransfer: 0,
-        start: time.Now(),
-        conn: conn,
+        id: conn.GetTransferID(),
+        limit: limit,
+        r: r,
+        start: time.Now(),
+        conn: conn,
     }
+    t.bytesRead.Store(0)
+    t.abortTransfer.Store(false)
     conn.AddTransfer(t)
     return t
 }

 type throttledReader struct {
-    bytesRead int64
+    bytesRead atomic.Int64
     id int64
     limit int64
     r io.ReadCloser
-    abortTransfer int32
+    abortTransfer atomic.Bool
     start time.Time
     conn *Connection
     mu sync.Mutex

@@ -271,7 +271,7 @@ func (t *throttledReader) GetType() int {
 }

 func (t *throttledReader) GetSize() int64 {
-    return atomic.LoadInt64(&t.bytesRead)
+    return t.bytesRead.Load()
 }

 func (t *throttledReader) GetDownloadedSize() int64 {

@@ -279,7 +279,7 @@ func (t *throttledReader) GetDownloadedSize() int64 {
 }

 func (t *throttledReader) GetUploadedSize() int64 {
-    return atomic.LoadInt64(&t.bytesRead)
+    return t.bytesRead.Load()
 }

 func (t *throttledReader) GetVirtualPath() string {

@@ -304,7 +304,7 @@ func (t *throttledReader) SignalClose(err error) {
     t.mu.Lock()
     t.errAbort = err
     t.mu.Unlock()
-    atomic.StoreInt32(&(t.abortTransfer), 1)
+    t.abortTransfer.Store(true)
 }

 func (t *throttledReader) GetTruncatedSize() int64 {

@@ -328,15 +328,15 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
 }

 func (t *throttledReader) Read(p []byte) (n int, err error) {
-    if atomic.LoadInt32(&t.abortTransfer) == 1 {
+    if t.abortTransfer.Load() {
         return 0, t.GetAbortError()
     }

     t.conn.UpdateLastActivity()
     n, err = t.r.Read(p)
     if t.limit > 0 {
-        atomic.AddInt64(&t.bytesRead, int64(n))
-        trasferredBytes := atomic.LoadInt64(&t.bytesRead)
+        t.bytesRead.Add(int64(n))
+        trasferredBytes := t.bytesRead.Load()
         elapsed := time.Since(t.start).Nanoseconds() / 1000000
         wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit
         if wantedElapsed > elapsed {
|
|||
|
||||
// Manager handles enabled plugins
|
||||
type Manager struct {
|
||||
closed int32
|
||||
closed atomic.Bool
|
||||
done chan bool
|
||||
// List of configured plugins
|
||||
Configs []Config `json:"plugins" mapstructure:"plugins"`
|
||||
|
@ -124,10 +124,10 @@ func Initialize(configs []Config, logLevel string) error {
|
|||
Handler = Manager{
|
||||
Configs: configs,
|
||||
done: make(chan bool),
|
||||
closed: 0,
|
||||
authScopes: -1,
|
||||
concurrencyGuard: make(chan struct{}, 250),
|
||||
}
|
||||
Handler.closed.Store(false)
|
||||
setLogLevel(logLevel)
|
||||
if len(configs) == 0 {
|
||||
return nil
|
||||
|
@ -604,7 +604,7 @@ func (m *Manager) checkCrashedPlugins() {
|
|||
}
|
||||
|
||||
func (m *Manager) restartNotifierPlugin(config Config, idx int) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed notifier plugin %#v, idx: %v", config.Cmd, idx)
|
||||
|
@ -622,7 +622,7 @@ func (m *Manager) restartNotifierPlugin(config Config, idx int) {
|
|||
}
|
||||
|
||||
func (m *Manager) restartKMSPlugin(config Config, idx int) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed kms plugin %#v, idx: %v", config.Cmd, idx)
|
||||
|
@ -638,7 +638,7 @@ func (m *Manager) restartKMSPlugin(config Config, idx int) {
|
|||
}
|
||||
|
||||
func (m *Manager) restartAuthPlugin(config Config, idx int) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed auth plugin %#v, idx: %v", config.Cmd, idx)
|
||||
|
@ -654,7 +654,7 @@ func (m *Manager) restartAuthPlugin(config Config, idx int) {
|
|||
}
|
||||
|
||||
func (m *Manager) restartSearcherPlugin(config Config) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed searcher plugin %#v", config.Cmd)
|
||||
|
@ -670,7 +670,7 @@ func (m *Manager) restartSearcherPlugin(config Config) {
|
|||
}
|
||||
|
||||
func (m *Manager) restartMetadaterPlugin(config Config) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed metadater plugin %#v", config.Cmd)
|
||||
|
@ -686,7 +686,7 @@ func (m *Manager) restartMetadaterPlugin(config Config) {
|
|||
}
|
||||
|
||||
func (m *Manager) restartIPFilterPlugin(config Config) {
|
||||
if atomic.LoadInt32(&m.closed) == 1 {
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
logger.Info(logSender, "", "try to restart crashed IP filter plugin %#v", config.Cmd)
|
||||
|
@ -712,7 +712,7 @@ func (m *Manager) removeTask() {
|
|||
// Cleanup releases all the active plugins
|
||||
func (m *Manager) Cleanup() {
|
||||
logger.Debug(logSender, "", "cleanup")
|
||||
atomic.StoreInt32(&m.closed, 1)
|
||||
m.closed.Store(true)
|
||||
close(m.done)
|
||||
m.notifLock.Lock()
|
||||
for _, n := range m.notifiers {
|
||||
|
|
|
@@ -1785,7 +1785,7 @@ func TestUploadError(t *testing.T) {
     if assert.Error(t, transfer.ErrTransfer) {
         assert.EqualError(t, transfer.ErrTransfer, errFake.Error())
     }
-    assert.Equal(t, int64(0), transfer.BytesReceived)
+    assert.Equal(t, int64(0), transfer.BytesReceived.Load())

     assert.NoFileExists(t, testfile)
     assert.NoFileExists(t, fileTempName)
@@ -1114,12 +1114,12 @@ func TestConcurrency(t *testing.T) {
     err = createTestFile(testFilePath, testFileSize)
     assert.NoError(t, err)

-    closedConns := int32(0)
+    var closedConns atomic.Int32
     for i := 0; i < numLogins; i++ {
         wg.Add(1)
         go func(counter int) {
             defer wg.Done()
-            defer atomic.AddInt32(&closedConns, 1)
+            defer closedConns.Add(1)

             conn, client, err := getSftpClient(user, usePubKey)
             if assert.NoError(t, err) {

@@ -1139,7 +1139,7 @@ func TestConcurrency(t *testing.T) {
     maxConns := 0
     maxSessions := 0
     for {
-        servedReqs := atomic.LoadInt32(&closedConns)
+        servedReqs := closedConns.Load()
         if servedReqs > 0 {
             stats := common.Connections.GetStats()
             if len(stats) > maxConns {
@@ -17,7 +17,6 @@ package sftpd
 import (
     "fmt"
     "io"
-    "sync/atomic"

     "github.com/eikenb/pipeat"

@@ -107,7 +106,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) {
     t.Connection.UpdateLastActivity()

     n, err = t.readerAt.ReadAt(p, off)
-    atomic.AddInt64(&t.BytesSent, int64(n))
+    t.BytesSent.Add(int64(n))

     if err == nil {
         err = t.CheckRead()

@@ -133,7 +132,7 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
     }

     n, err = t.writerAt.WriteAt(p, off)
-    atomic.AddInt64(&t.BytesReceived, int64(n))
+    t.BytesReceived.Add(int64(n))

     if err == nil {
         err = t.CheckWrite()

@@ -213,13 +212,13 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
     if nw > 0 {
         written += int64(nw)
         if isDownload {
-            atomic.StoreInt64(&t.BytesSent, written)
+            t.BytesSent.Store(written)
             if errCheck := t.CheckRead(); errCheck != nil {
                 err = errCheck
                 break
             }
         } else {
-            atomic.StoreInt64(&t.BytesReceived, written)
+            t.BytesReceived.Store(written)
             if errCheck := t.CheckWrite(); errCheck != nil {
                 err = errCheck
                 break

@@ -245,7 +244,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
     }
     t.ErrTransfer = err
     if written > 0 || err != nil {
-        metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.GetType(),
+        metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.GetType(),
            t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
     }
     return written, err
@@ -32,14 +32,14 @@ func (l *listener) Accept() (net.Conn, error) {
     return nil, err
 }
 tc := &Conn{
-    Conn: c,
-    ReadTimeout: l.ReadTimeout,
-    WriteTimeout: l.WriteTimeout,
-    ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second),
-    WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
-    BytesReadFromDeadline: 0,
-    BytesWrittenFromDeadline: 0,
+    Conn: c,
+    ReadTimeout: l.ReadTimeout,
+    WriteTimeout: l.WriteTimeout,
+    ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second),
+    WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
 }
+tc.BytesReadFromDeadline.Store(0)
+tc.BytesWrittenFromDeadline.Store(0)
 return tc, nil
 }

@@ -51,13 +51,13 @@ type Conn struct {
     WriteTimeout time.Duration
     ReadThreshold int32
     WriteThreshold int32
-    BytesReadFromDeadline int32
-    BytesWrittenFromDeadline int32
+    BytesReadFromDeadline atomic.Int32
+    BytesWrittenFromDeadline atomic.Int32
 }

 func (c *Conn) Read(b []byte) (n int, err error) {
-    if atomic.LoadInt32(&c.BytesReadFromDeadline) > c.ReadThreshold {
-        atomic.StoreInt32(&c.BytesReadFromDeadline, 0)
+    if c.BytesReadFromDeadline.Load() > c.ReadThreshold {
+        c.BytesReadFromDeadline.Store(0)
         // we set both read and write deadlines here otherwise after the request
         // is read writing the response fails with an i/o timeout error
         err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout))

@@ -66,13 +66,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
         }
     }
     n, err = c.Conn.Read(b)
-    atomic.AddInt32(&c.BytesReadFromDeadline, int32(n))
+    c.BytesReadFromDeadline.Add(int32(n))
     return
 }

 func (c *Conn) Write(b []byte) (n int, err error) {
-    if atomic.LoadInt32(&c.BytesWrittenFromDeadline) > c.WriteThreshold {
-        atomic.StoreInt32(&c.BytesWrittenFromDeadline, 0)
+    if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold {
+        c.BytesWrittenFromDeadline.Store(0)
         // we extend the read deadline too, not sure it's necessary,
         // but it doesn't hurt
         err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout))

@@ -81,7 +81,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
         }
     }
     n, err = c.Conn.Write(b)
-    atomic.AddInt32(&c.BytesWrittenFromDeadline, int32(n))
+    c.BytesWrittenFromDeadline.Add(int32(n))
     return
 }
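The Conn wrapper extends its I/O deadline only after roughly a timeout's worth of kilobytes has moved (threshold = timeout in seconds * 1024 bytes), so SetDeadline is not called on every Read. With atomic.Int32 the byte counters can be updated from concurrent readers and writers without a lock. A reduced sketch of the idea (hypothetical wrapper, read path only):

    package main

    import (
        "fmt"
        "net"
        "sync/atomic"
        "time"
    )

    // idleConn resets the connection deadline only every `threshold`
    // bytes, amortizing the cost of SetDeadline across many Read calls.
    // The counter is an atomic.Int32 so Read can race safely with other
    // goroutines inspecting it.
    type idleConn struct {
        net.Conn
        timeout   time.Duration
        threshold int32
        readBytes atomic.Int32
    }

    func (c *idleConn) Read(b []byte) (int, error) {
        if c.readBytes.Load() > c.threshold {
            c.readBytes.Store(0)
            if err := c.Conn.SetDeadline(time.Now().Add(c.timeout)); err != nil {
                return 0, err
            }
        }
        n, err := c.Conn.Read(b)
        c.readBytes.Add(int32(n))
        return n, err
    }

    func main() {
        // Usage sketch: wrap an accepted connection in a listener's Accept.
        c := &idleConn{timeout: 30 * time.Second}
        c.threshold = int32((c.timeout * 1024) / time.Second)
        fmt.Println(c.threshold) // 30720 bytes between deadline refreshes
    }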
@@ -902,7 +902,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
     finished := false
     var wg sync.WaitGroup
     var errOnce sync.Once
-    var hasError int32
+    var hasError atomic.Bool
     var poolError error

     poolCtx, poolCancel := context.WithCancel(ctx)

@@ -919,7 +919,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
     offset = end

     guard <- struct{}{}
-    if atomic.LoadInt32(&hasError) == 1 {
+    if hasError.Load() {
         fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part)
         break
     }

@@ -941,7 +941,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
     if err != nil {
         errOnce.Do(func() {
             fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
-            atomic.StoreInt32(&hasError, 1)
+            hasError.Store(true)
             poolError = fmt.Errorf("multipart download error: %w", err)
             poolCancel()
         })

@@ -971,7 +971,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
     var blocks []string
     var wg sync.WaitGroup
     var errOnce sync.Once
-    var hasError int32
+    var hasError atomic.Bool
     var poolError error

     poolCtx, poolCancel := context.WithCancel(ctx)

@@ -999,7 +999,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
     blocks = append(blocks, blockID)

     guard <- struct{}{}
-    if atomic.LoadInt32(&hasError) == 1 {
+    if hasError.Load() {
         fsLog(fs, logger.LevelError, "pool error, upload for part %v not started", part)
         pool.releaseBuffer(buf)
         break

@@ -1023,7 +1023,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
     if err != nil {
         errOnce.Do(func() {
             fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err)
-            atomic.StoreInt32(&hasError, 1)
+            hasError.Store(true)
             poolError = fmt.Errorf("multipart upload error: %w", err)
             poolCancel()
         })
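The multipart upload/download helpers combine three primitives: a sync.Once so only the first failure is recorded, an atomic.Bool so later iterations can skip scheduling work cheaply, and a context cancel to stop parts already in flight. A compact sketch of that arrangement (hypothetical worker; the real code dispatches Azure or S3 part operations):

    package main

    import (
        "context"
        "errors"
        "fmt"
        "sync"
        "sync/atomic"
    )

    func process(parts []int, work func(context.Context, int) error) error {
        var wg sync.WaitGroup
        var errOnce sync.Once
        var hasError atomic.Bool
        var poolError error

        ctx, cancel := context.WithCancel(context.Background())
        defer cancel()

        for _, part := range parts {
            if hasError.Load() { // cheap check: stop scheduling after a failure
                break
            }
            wg.Add(1)
            go func(part int) {
                defer wg.Done()
                if err := work(ctx, part); err != nil {
                    errOnce.Do(func() { // only the first error is kept
                        hasError.Store(true)
                        poolError = fmt.Errorf("part %d: %w", part, err)
                        cancel() // abort the parts still running
                    })
                }
            }(part)
        }
        wg.Wait()
        return poolError
    }

    func main() {
        err := process([]int{1, 2, 3}, func(_ context.Context, p int) error {
            if p == 2 {
                return errors.New("boom")
            }
            return nil
        })
        fmt.Println(err) // part 2: boom
    }

Reading poolError after wg.Wait() is safe because the Once and the WaitGroup together establish the necessary happens-before ordering, which is exactly the structure the hunks above preserve.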
@@ -835,7 +835,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
     var completedParts []types.CompletedPart
     var partMutex sync.Mutex
     var wg sync.WaitGroup
-    var hasError int32
+    var hasError atomic.Bool
     var errOnce sync.Once
     var copyError error
     var partNumber int32

@@ -854,7 +854,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
     offset = end

     guard <- struct{}{}
-    if atomic.LoadInt32(&hasError) == 1 {
+    if hasError.Load() {
         fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber)
         break
     }

@@ -880,7 +880,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
     if err != nil {
         errOnce.Do(func() {
             fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err)
-            atomic.StoreInt32(&hasError, 1)
+            hasError.Store(true)
             copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
             opCancel()
@@ -42,7 +42,7 @@ type webDavFile struct {
     info os.FileInfo
     startOffset int64
     isFinished bool
-    readTryed int32
+    readTryed atomic.Bool
 }

 func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *webDavFile {

@@ -56,15 +56,16 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter
     } else if pipeReader != nil {
         reader = pipeReader
     }
-    return &webDavFile{
+    f := &webDavFile{
         BaseTransfer: baseTransfer,
         writer: writer,
         reader: reader,
         isFinished: false,
         startOffset: 0,
         info: nil,
-        readTryed: 0,
     }
+    f.readTryed.Store(false)
+    return f
 }

 type webDavFileInfo struct {

@@ -124,7 +125,7 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {
     f.Unlock()
     if f.GetType() == common.TransferUpload && errUpload == nil {
         info := &webDavFileInfo{
-            FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Unix(0, 0), false),
+            FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Unix(0, 0), false),
             Fs: f.Fs,
             virtualPath: f.GetVirtualPath(),
             fsPath: f.GetFsPath(),

@@ -149,10 +150,10 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {

 // Read reads the contents to downloads.
 func (f *webDavFile) Read(p []byte) (n int, err error) {
-    if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+    if f.AbortTransfer.Load() {
         return 0, errTransferAborted
     }
-    if atomic.LoadInt32(&f.readTryed) == 0 {
+    if !f.readTryed.Load() {
         if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) {
             return 0, f.Connection.GetPermissionDeniedError()
         }

@@ -171,7 +172,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
         f.Connection.Log(logger.LevelDebug, "download for file %#v denied by pre action: %v", f.GetVirtualPath(), err)
         return 0, f.Connection.GetPermissionDeniedError()
     }
-    atomic.StoreInt32(&f.readTryed, 1)
+    f.readTryed.Store(true)
 }

 f.Connection.UpdateLastActivity()

@@ -198,7 +199,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
     }

     n, err = f.reader.Read(p)
-    atomic.AddInt64(&f.BytesSent, int64(n))
+    f.BytesSent.Add(int64(n))
     if err == nil {
         err = f.CheckRead()
     }

@@ -212,14 +213,14 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {

 // Write writes the uploaded contents.
 func (f *webDavFile) Write(p []byte) (n int, err error) {
-    if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+    if f.AbortTransfer.Load() {
         return 0, errTransferAborted
     }

     f.Connection.UpdateLastActivity()

     n, err = f.writer.Write(p)
-    atomic.AddInt64(&f.BytesReceived, int64(n))
+    f.BytesReceived.Add(int64(n))

     if err == nil {
         err = f.CheckWrite()

@@ -252,7 +253,7 @@ func (f *webDavFile) updateTransferQuotaOnSeek() {
     if transferQuota.HasSizeLimits() {
         go func(ulSize, dlSize int64, user dataprovider.User) {
             dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-        }(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
+        }(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User)
     }
 }

@@ -270,7 +271,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
     return ret, err
 }
 if f.GetType() == common.TransferDownload {
-    readOffset := f.startOffset + atomic.LoadInt64(&f.BytesSent)
+    readOffset := f.startOffset + f.BytesSent.Load()
     if offset == 0 && readOffset == 0 {
         if whence == io.SeekStart {
             return 0, nil

@@ -288,8 +289,8 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
     f.reader = nil
 }
 startByte := int64(0)
-atomic.StoreInt64(&f.BytesReceived, 0)
-atomic.StoreInt64(&f.BytesSent, 0)
+f.BytesReceived.Store(0)
+f.BytesSent.Store(0)
 f.updateTransferQuotaOnSeek()

 switch whence {

@@ -369,7 +370,7 @@ func (f *webDavFile) setFinished() error {

 func (f *webDavFile) isTransfer() bool {
     if f.GetType() == common.TransferDownload {
-        return atomic.LoadInt32(&f.readTryed) > 0
+        return f.readTryed.Load()
     }
     return true
 }
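readTryed is perhaps the cleanest illustration of why the typed API reads better: the old int32 was compared with == 0 in one place and > 0 in another to mean the same thing, while atomic.Bool admits exactly !Load() and Load(). A final small sketch of the once-per-transfer gate (hypothetical type):

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    type download struct {
        checked atomic.Bool
    }

    // read performs its permission checks only on the first call; the
    // atomic.Bool gate lets every subsequent read skip them.
    func (d *download) read() string {
        if !d.checked.Load() {
            // ... permission and pre-action checks would run here ...
            d.checked.Store(true)
            return "checked, then read"
        }
        return "read"
    }

    func main() {
        var d download
        fmt.Println(d.read()) // checked, then read
        fmt.Println(d.read()) // read
    }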