Bladeren bron

use the new atomic types introduced in Go 1.19

we depend on Go 1.19 anyway

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 jaren geleden
bovenliggende
commit
95e9106902

+ 4 - 4
internal/common/clientsmap.go

@@ -23,13 +23,13 @@ import (
 
 // clienstMap is a struct containing the map of the connected clients
 type clientsMap struct {
-	totalConnections int32
+	totalConnections atomic.Int32
 	mu               sync.RWMutex
 	clients          map[string]int
 }
 
 func (c *clientsMap) add(source string) {
-	atomic.AddInt32(&c.totalConnections, 1)
+	c.totalConnections.Add(1)
 
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) {
 	defer c.mu.Unlock()
 
 	if val, ok := c.clients[source]; ok {
-		atomic.AddInt32(&c.totalConnections, -1)
+		c.totalConnections.Add(-1)
 		c.clients[source]--
 		if val > 1 {
 			return
@@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) {
 }
 
 func (c *clientsMap) getTotal() int32 {
-	return atomic.LoadInt32(&c.totalConnections)
+	return c.totalConnections.Load()
 }
 
 func (c *clientsMap) getTotalFrom(source string) int {

+ 12 - 11
internal/common/common.go

@@ -704,16 +704,17 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
 type SSHConnection struct {
 	id           string
 	conn         net.Conn
-	lastActivity int64
+	lastActivity atomic.Int64
 }
 
 // NewSSHConnection returns a new SSHConnection
 func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
-	return &SSHConnection{
-		id:           id,
-		conn:         conn,
-		lastActivity: time.Now().UnixNano(),
+	c := &SSHConnection{
+		id:   id,
+		conn: conn,
 	}
+	c.lastActivity.Store(time.Now().UnixNano())
+	return c
 }
 
 // GetID returns the ID for this SSHConnection
@@ -723,12 +724,12 @@ func (c *SSHConnection) GetID() string {
 
 // UpdateLastActivity updates last activity for this connection
 func (c *SSHConnection) UpdateLastActivity() {
-	atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
+	c.lastActivity.Store(time.Now().UnixNano())
 }
 
 // GetLastActivity returns the last connection activity
 func (c *SSHConnection) GetLastActivity() time.Time {
-	return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
+	return time.Unix(0, c.lastActivity.Load())
 }
 
 // Close closes the underlying network connection
@@ -741,7 +742,7 @@ type ActiveConnections struct {
 	// clients contains both authenticated and estabilished connections and the ones waiting
 	// for authentication
 	clients              clientsMap
-	transfersCheckStatus int32
+	transfersCheckStatus atomic.Bool
 	sync.RWMutex
 	connections    []ActiveConnection
 	sshConnections []*SSHConnection
@@ -953,12 +954,12 @@ func (conns *ActiveConnections) checkIdles() {
 }
 
 func (conns *ActiveConnections) checkTransfers() {
-	if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
+	if conns.transfersCheckStatus.Load() {
 		logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
 		return
 	}
-	atomic.StoreInt32(&conns.transfersCheckStatus, 1)
-	defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
+	conns.transfersCheckStatus.Store(true)
+	defer conns.transfersCheckStatus.Store(false)
 
 	conns.RLock()
 

+ 10 - 11
internal/common/common_test.go

@@ -24,7 +24,6 @@ import (
 	"path/filepath"
 	"runtime"
 	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -498,14 +497,14 @@ func TestIdleConnections(t *testing.T) {
 		},
 	}
 	c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
-	c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
+	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
 	// both ssh connections are expired but they should get removed only
 	// if there is no associated connection
-	sshConn1.lastActivity = c.lastActivity
-	sshConn2.lastActivity = c.lastActivity
+	sshConn1.lastActivity.Store(c.lastActivity.Load())
+	sshConn2.lastActivity.Store(c.lastActivity.Load())
 	Connections.AddSSHConnection(sshConn1)
 	err = Connections.Add(fakeConn)
 	assert.NoError(t, err)
@@ -520,7 +519,7 @@ func TestIdleConnections(t *testing.T) {
 	assert.Equal(t, Connections.GetActiveSessions(username), 2)
 
 	cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
-	cFTP.lastActivity = time.Now().UnixNano()
+	cFTP.lastActivity.Store(time.Now().UnixNano())
 	fakeConn = &fakeConnection{
 		BaseConnection: cFTP,
 	}
@@ -541,9 +540,9 @@ func TestIdleConnections(t *testing.T) {
 	}, 1*time.Second, 200*time.Millisecond)
 	stopEventScheduler()
 	assert.Len(t, Connections.GetStats(), 2)
-	c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
-	cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
-	sshConn2.lastActivity = c.lastActivity
+	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
+	cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
+	sshConn2.lastActivity.Store(c.lastActivity.Load())
 	startPeriodicChecks(100 * time.Millisecond)
 	assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond)
 	assert.Eventually(t, func() bool {
@@ -646,9 +645,9 @@ func TestConnectionStatus(t *testing.T) {
 		BaseConnection: c1,
 	}
 	t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-	t1.BytesReceived = 123
+	t1.BytesReceived.Store(123)
 	t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-	t2.BytesSent = 456
+	t2.BytesSent.Store(456)
 	c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
 	fakeConn2 := &fakeConnection{
 		BaseConnection: c2,
@@ -698,7 +697,7 @@ func TestConnectionStatus(t *testing.T) {
 
 	err = fakeConn3.SignalTransfersAbort()
 	assert.NoError(t, err)
-	assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
+	assert.True(t, t3.AbortTransfer.Load())
 	err = t3.Close()
 	assert.NoError(t, err)
 	err = fakeConn3.SignalTransfersAbort()

+ 16 - 14
internal/common/connection.go

@@ -38,12 +38,12 @@ import (
 type BaseConnection struct {
 	// last activity for this connection.
 	// Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
-	lastActivity int64
+	lastActivity atomic.Int64
 	uploadDone   atomic.Bool
 	downloadDone atomic.Bool
 	// unique ID for a transfer.
 	// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
-	transferID int64
+	transferID atomic.Int64
 	// Unique identifier for the connection
 	ID string
 	// user associated with this connection if any
@@ -64,16 +64,18 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov
 		connID = fmt.Sprintf("%s_%s", protocol, id)
 	}
 	user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
-	return &BaseConnection{
-		ID:           connID,
-		User:         user,
-		startTime:    time.Now(),
-		protocol:     protocol,
-		localAddr:    localAddr,
-		remoteAddr:   remoteAddr,
-		lastActivity: time.Now().UnixNano(),
-		transferID:   0,
+	c := &BaseConnection{
+		ID:         connID,
+		User:       user,
+		startTime:  time.Now(),
+		protocol:   protocol,
+		localAddr:  localAddr,
+		remoteAddr: remoteAddr,
 	}
+	c.transferID.Store(0)
+	c.lastActivity.Store(time.Now().UnixNano())
+
+	return c
 }
 
 // Log outputs a log entry to the configured logger
@@ -83,7 +85,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {
 
 // GetTransferID returns an unique transfer ID for this connection
 func (c *BaseConnection) GetTransferID() int64 {
-	return atomic.AddInt64(&c.transferID, 1)
+	return c.transferID.Add(1)
 }
 
 // GetID returns the connection ID
@@ -126,12 +128,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time {
 
 // UpdateLastActivity updates last activity for this connection
 func (c *BaseConnection) UpdateLastActivity() {
-	atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
+	c.lastActivity.Store(time.Now().UnixNano())
 }
 
 // GetLastActivity returns the last connection activity
 func (c *BaseConnection) GetLastActivity() time.Time {
-	return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
+	return time.Unix(0, c.lastActivity.Load())
 }
 
 // CloseFS closes the underlying fs

+ 3 - 3
internal/common/eventmanager.go

@@ -82,7 +82,7 @@ func HandleCertificateEvent(params EventParams) {
 // eventRulesContainer stores event rules by trigger
 type eventRulesContainer struct {
 	sync.RWMutex
-	lastLoad          int64
+	lastLoad          atomic.Int64
 	FsEvents          []dataprovider.EventRule
 	ProviderEvents    []dataprovider.EventRule
 	Schedules         []dataprovider.EventRule
@@ -101,11 +101,11 @@ func (r *eventRulesContainer) removeAsyncTask() {
 }
 
 func (r *eventRulesContainer) getLastLoadTime() int64 {
-	return atomic.LoadInt64(&r.lastLoad)
+	return r.lastLoad.Load()
 }
 
 func (r *eventRulesContainer) setLastLoadTime(modTime int64) {
-	atomic.StoreInt64(&r.lastLoad, modTime)
+	r.lastLoad.Store(modTime)
 }
 
 // RemoveRule deletes the rule with the specified name

+ 5 - 4
internal/common/ratelimiter.go

@@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
 }
 
 type sourceRateLimiter struct {
-	lastActivity int64
+	lastActivity *atomic.Int64
 	bucket       *rate.Limiter
 }
 
 func (s *sourceRateLimiter) updateLastActivity() {
-	atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
+	s.lastActivity.Store(time.Now().UnixNano())
 }
 
 func (s *sourceRateLimiter) getLastActivity() int64 {
-	return atomic.LoadInt64(&s.lastActivity)
+	return s.lastActivity.Load()
 }
 
 type sourceBuckets struct {
@@ -224,7 +224,8 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese
 	b.cleanup()
 
 	src := sourceRateLimiter{
-		bucket: r,
+		lastActivity: new(atomic.Int64),
+		bucket:       r,
 	}
 	src.updateLastActivity()
 	b.buckets[source] = src

+ 37 - 37
internal/common/transfer.go

@@ -35,8 +35,8 @@ var (
 // BaseTransfer contains protocols common transfer details for an upload or a download.
 type BaseTransfer struct { //nolint:maligned
 	ID              int64
-	BytesSent       int64
-	BytesReceived   int64
+	BytesSent       atomic.Int64
+	BytesReceived   atomic.Int64
 	Fs              vfs.Fs
 	File            vfs.File
 	Connection      *BaseConnection
@@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned
 	truncatedSize   int64
 	isNewFile       bool
 	transferType    int
-	AbortTransfer   int32
+	AbortTransfer   atomic.Bool
 	aTime           time.Time
 	mTime           time.Time
 	transferQuota   dataprovider.TransferQuota
@@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
 		InitialSize:     initialSize,
 		isNewFile:       isNewFile,
 		requestPath:     requestPath,
-		BytesSent:       0,
-		BytesReceived:   0,
 		MaxWriteSize:    maxWriteSize,
-		AbortTransfer:   0,
 		truncatedSize:   truncatedSize,
 		transferQuota:   transferQuota,
 		Fs:              fs,
 	}
+	t.AbortTransfer.Store(false)
+	t.BytesSent.Store(0)
+	t.BytesReceived.Store(0)
 
 	conn.AddTransfer(t)
 	return t
@@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int {
 // GetSize returns the transferred size
 func (t *BaseTransfer) GetSize() int64 {
 	if t.transferType == TransferDownload {
-		return atomic.LoadInt64(&t.BytesSent)
+		return t.BytesSent.Load()
 	}
-	return atomic.LoadInt64(&t.BytesReceived)
+	return t.BytesReceived.Load()
 }
 
 // GetDownloadedSize returns the transferred size
 func (t *BaseTransfer) GetDownloadedSize() int64 {
-	return atomic.LoadInt64(&t.BytesSent)
+	return t.BytesSent.Load()
 }
 
 // GetUploadedSize returns the transferred size
 func (t *BaseTransfer) GetUploadedSize() int64 {
-	return atomic.LoadInt64(&t.BytesReceived)
+	return t.BytesReceived.Load()
 }
 
 // GetStartTime returns the start time
@@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) {
 	t.Lock()
 	t.errAbort = err
 	t.Unlock()
-	atomic.StoreInt32(&(t.AbortTransfer), 1)
+	t.AbortTransfer.Store(true)
 }
 
 // GetTruncatedSize returns the truncated sized if this is an upload overwriting
@@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error {
 		return nil
 	}
 	if t.transferQuota.AllowedTotalSize > 0 {
-		if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
+		if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
 			return t.Connection.GetReadQuotaExceededError()
 		}
 	} else if t.transferQuota.AllowedDLSize > 0 {
-		if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize {
+		if t.BytesSent.Load() > t.transferQuota.AllowedDLSize {
 			return t.Connection.GetReadQuotaExceededError()
 		}
 	}
@@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error {
 
 // CheckWrite returns an error if write if not allowed
 func (t *BaseTransfer) CheckWrite() error {
-	if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize {
+	if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize {
 		return t.Connection.GetQuotaExceededError()
 	}
 	if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
 		return nil
 	}
 	if t.transferQuota.AllowedTotalSize > 0 {
-		if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
+		if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
 			return t.Connection.GetQuotaExceededError()
 		}
 	} else if t.transferQuota.AllowedULSize > 0 {
-		if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize {
+		if t.BytesReceived.Load() > t.transferQuota.AllowedULSize {
 			return t.Connection.GetQuotaExceededError()
 		}
 	}
@@ -261,14 +261,14 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
 				if t.MaxWriteSize > 0 {
 					sizeDiff := initialSize - size
 					t.MaxWriteSize += sizeDiff
-					metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
+					metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
 						t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
 					if t.transferQuota.HasSizeLimits() {
 						go func(ulSize, dlSize int64, user dataprovider.User) {
 							dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-						}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
+						}(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User)
 					}
-					atomic.StoreInt64(&t.BytesReceived, 0)
+					t.BytesReceived.Store(0)
 				}
 				t.Unlock()
 			}
@@ -276,7 +276,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
 				fsPath, size, t.MaxWriteSize, t.InitialSize, err)
 			return initialSize, err
 		}
-		if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 {
+		if size == 0 && t.BytesSent.Load() == 0 {
 			// for cloud providers the file is always truncated to zero, we don't support append/resume for uploads
 			// for buffered SFTP we can have buffered bytes so we returns an error
 			if !vfs.IsBufferedSFTPFs(t.Fs) {
@@ -302,8 +302,8 @@ func (t *BaseTransfer) TransferError(err error) {
 	}
 	elapsed := time.Since(t.start).Nanoseconds() / 1000000
 	t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
-		"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent),
-		atomic.LoadInt64(&t.BytesReceived), elapsed)
+		"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
+		t.BytesReceived.Load(), elapsed)
 }
 
 func (t *BaseTransfer) getUploadFileSize() (int64, error) {
@@ -333,7 +333,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
 	t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v",
 		t.effectiveFsPath, err)
 	// the file is outside the home dir so don't update the quota
-	atomic.StoreInt64(&t.BytesReceived, 0)
+	t.BytesReceived.Store(0)
 	t.MinWriteOffset = 0
 	return 1
 }
@@ -351,18 +351,18 @@ func (t *BaseTransfer) Close() error {
 	if t.isNewFile {
 		numFiles = 1
 	}
-	metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
+	metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
 		t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
 	if t.transferQuota.HasSizeLimits() {
-		dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
-			atomic.LoadInt64(&t.BytesSent), false)
+		dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
+			t.BytesSent.Load(), false)
 	}
 	if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
 		// if quota is exceeded we try to remove the partial file for uploads to local filesystem
 		err = t.Fs.Remove(t.File.Name(), false)
 		if err == nil {
 			numFiles--
-			atomic.StoreInt64(&t.BytesReceived, 0)
+			t.BytesReceived.Store(0)
 			t.MinWriteOffset = 0
 		}
 		t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
@@ -380,7 +380,7 @@ func (t *BaseTransfer) Close() error {
 				t.ErrTransfer, t.effectiveFsPath, err)
 			if err == nil {
 				numFiles--
-				atomic.StoreInt64(&t.BytesReceived, 0)
+				t.BytesReceived.Store(0)
 				t.MinWriteOffset = 0
 			}
 		}
@@ -388,12 +388,12 @@ func (t *BaseTransfer) Close() error {
 	elapsed := time.Since(t.start).Nanoseconds() / 1000000
 	var uploadFileSize int64
 	if t.transferType == TransferDownload {
-		logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username,
+		logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
 			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
 		ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
-			atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
+			t.BytesSent.Load(), t.ErrTransfer)
 	} else {
-		uploadFileSize = atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
+		uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
 		if statSize, errStat := t.getUploadFileSize(); errStat == nil {
 			uploadFileSize = statSize
 		}
@@ -401,7 +401,7 @@ func (t *BaseTransfer) Close() error {
 		numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize)
 		t.updateQuota(numFiles, uploadFileSize)
 		t.updateTimes()
-		logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username,
+		logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
 			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
 	}
 	if t.ErrTransfer != nil {
@@ -428,11 +428,11 @@ func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) {
 		}
 		return
 	}
-	if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && atomic.LoadInt64(&t.BytesSent) > 0 {
+	if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
 		if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
 			t.Connection.downloadDone.Store(true)
 			ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
-				"", "", atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
+				"", "", t.BytesSent.Load(), t.ErrTransfer)
 		}
 	}
 }
@@ -449,7 +449,7 @@ func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int
 		if err == nil {
 			numFiles--
 			fileSize = 0
-			atomic.StoreInt64(&t.BytesReceived, 0)
+			t.BytesReceived.Store(0)
 			t.MinWriteOffset = 0
 		} else {
 			t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
@@ -494,10 +494,10 @@ func (t *BaseTransfer) HandleThrottle() {
 	var trasferredBytes int64
 	if t.transferType == TransferDownload {
 		wantedBandwidth = t.Connection.User.DownloadBandwidth
-		trasferredBytes = atomic.LoadInt64(&t.BytesSent)
+		trasferredBytes = t.BytesSent.Load()
 	} else {
 		wantedBandwidth = t.Connection.User.UploadBandwidth
-		trasferredBytes = atomic.LoadInt64(&t.BytesReceived)
+		trasferredBytes = t.BytesReceived.Load()
 	}
 	if wantedBandwidth > 0 {
 		// real and wanted elapsed as milliseconds, bytes as kilobytes

+ 24 - 24
internal/common/transfer_test.go

@@ -33,11 +33,11 @@ import (
 func TestTransferUpdateQuota(t *testing.T) {
 	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
 	transfer := BaseTransfer{
-		Connection:    conn,
-		transferType:  TransferUpload,
-		BytesReceived: 123,
-		Fs:            vfs.NewOsFs("", os.TempDir(), ""),
+		Connection:   conn,
+		transferType: TransferUpload,
+		Fs:           vfs.NewOsFs("", os.TempDir(), ""),
 	}
+	transfer.BytesReceived.Store(123)
 	errFake := errors.New("fake error")
 	transfer.TransferError(errFake)
 	assert.False(t, transfer.updateQuota(1, 0))
@@ -56,7 +56,7 @@ func TestTransferUpdateQuota(t *testing.T) {
 		QuotaSize:   -1,
 	})
 	transfer.ErrTransfer = nil
-	transfer.BytesReceived = 1
+	transfer.BytesReceived.Store(1)
 	transfer.requestPath = "/vdir/file"
 	assert.True(t, transfer.updateQuota(1, 0))
 	err = transfer.Close()
@@ -80,7 +80,7 @@ func TestTransferThrottling(t *testing.T) {
 	wantedDownloadElapsed -= wantedDownloadElapsed / 10
 	conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
 	transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-	transfer.BytesReceived = testFileSize
+	transfer.BytesReceived.Store(testFileSize)
 	transfer.Connection.UpdateLastActivity()
 	startTime := transfer.Connection.GetLastActivity()
 	transfer.HandleThrottle()
@@ -90,7 +90,7 @@ func TestTransferThrottling(t *testing.T) {
 	assert.NoError(t, err)
 
 	transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-	transfer.BytesSent = testFileSize
+	transfer.BytesSent.Store(testFileSize)
 	transfer.Connection.UpdateLastActivity()
 	startTime = transfer.Connection.GetLastActivity()
 
@@ -226,7 +226,7 @@ func TestTransferErrors(t *testing.T) {
 	assert.Equal(t, testFile, transfer.GetFsPath())
 	transfer.SetCancelFn(cancelFn)
 	errFake := errors.New("err fake")
-	transfer.BytesReceived = 9
+	transfer.BytesReceived.Store(9)
 	transfer.TransferError(ErrQuotaExceeded)
 	assert.True(t, isCancelled)
 	transfer.TransferError(errFake)
@@ -249,7 +249,7 @@ func TestTransferErrors(t *testing.T) {
 	fsPath := filepath.Join(os.TempDir(), "test_file")
 	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
 		fs, dataprovider.TransferQuota{})
-	transfer.BytesReceived = 9
+	transfer.BytesReceived.Store(9)
 	transfer.TransferError(errFake)
 	assert.Error(t, transfer.ErrTransfer, errFake.Error())
 	// the file is closed from the embedding struct before to call close
@@ -269,7 +269,7 @@ func TestTransferErrors(t *testing.T) {
 	}
 	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
 		fs, dataprovider.TransferQuota{})
-	transfer.BytesReceived = 9
+	transfer.BytesReceived.Store(9)
 	// the file is closed from the embedding struct before to call close
 	err = file.Close()
 	assert.NoError(t, err)
@@ -310,11 +310,11 @@ func TestRemovePartialCryptoFile(t *testing.T) {
 func TestFTPMode(t *testing.T) {
 	conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
 	transfer := BaseTransfer{
-		Connection:    conn,
-		transferType:  TransferUpload,
-		BytesReceived: 123,
-		Fs:            vfs.NewOsFs("", os.TempDir(), ""),
+		Connection:   conn,
+		transferType: TransferUpload,
+		Fs:           vfs.NewOsFs("", os.TempDir(), ""),
 	}
+	transfer.BytesReceived.Store(123)
 	assert.Empty(t, transfer.ftpMode)
 	transfer.SetFtpMode("active")
 	assert.Equal(t, "active", transfer.ftpMode)
@@ -399,14 +399,14 @@ func TestTransferQuota(t *testing.T) {
 	transfer.transferQuota = dataprovider.TransferQuota{
 		AllowedTotalSize: 10,
 	}
-	transfer.BytesReceived = 5
-	transfer.BytesSent = 4
+	transfer.BytesReceived.Store(5)
+	transfer.BytesSent.Store(4)
 	err = transfer.CheckRead()
 	assert.NoError(t, err)
 	err = transfer.CheckWrite()
 	assert.NoError(t, err)
 
-	transfer.BytesSent = 6
+	transfer.BytesSent.Store(6)
 	err = transfer.CheckRead()
 	if assert.Error(t, err) {
 		assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -428,7 +428,7 @@ func TestTransferQuota(t *testing.T) {
 	err = transfer.CheckWrite()
 	assert.NoError(t, err)
 
-	transfer.BytesReceived = 11
+	transfer.BytesReceived.Store(11)
 	err = transfer.CheckRead()
 	if assert.Error(t, err) {
 		assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -442,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
 
 	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
 	transfer := BaseTransfer{
-		Connection:    conn,
-		transferType:  TransferUpload,
-		BytesReceived: 123,
-		Fs:            vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
+		Connection:   conn,
+		transferType: TransferUpload,
+		Fs:           vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
 	}
+	transfer.BytesReceived.Store(123)
 
 	fileName := filepath.Join(os.TempDir(), "_temp")
 	err := os.WriteFile(fileName, []byte(`data`), 0644)
@@ -459,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
 	Config.TempPath = filepath.Clean(os.TempDir())
 	res = transfer.checkUploadOutsideHomeDir(nil)
 	assert.Equal(t, 0, res)
-	assert.Greater(t, transfer.BytesReceived, int64(0))
+	assert.Greater(t, transfer.BytesReceived.Load(), int64(0))
 	res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
 	assert.Equal(t, 1, res)
-	assert.Equal(t, int64(0), transfer.BytesReceived)
+	assert.Equal(t, int64(0), transfer.BytesReceived.Load())
 	assert.NoFileExists(t, fileName)
 
 	Config.TempPath = oldTempPath

+ 25 - 26
internal/common/transferschecker_test.go

@@ -21,7 +21,6 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -96,7 +95,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	}
 	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
-	transfer1.BytesReceived = 150
+	transfer1.BytesReceived.Store(150)
 	err = Connections.Add(fakeConn1)
 	assert.NoError(t, err)
 	// the transferschecker will do nothing if there is only one ongoing transfer
@@ -110,8 +109,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	}
 	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
 		"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
-	transfer1.BytesReceived = 50
-	transfer2.BytesReceived = 60
+	transfer1.BytesReceived.Store(50)
+	transfer2.BytesReceived.Store(60)
 	err = Connections.Add(fakeConn2)
 	assert.NoError(t, err)
 
@@ -122,7 +121,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	}
 	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
 		"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
-	transfer3.BytesReceived = 60 // this value will be ignored, this is a download
+	transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download
 	err = Connections.Add(fakeConn3)
 	assert.NoError(t, err)
 
@@ -132,20 +131,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	assert.Nil(t, transfer2.errAbort)
 	assert.Nil(t, transfer3.errAbort)
 
-	transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
+	transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
 	assert.Nil(t, transfer2.errAbort)
 	assert.Nil(t, transfer3.errAbort)
-	transfer1.BytesReceived = 120
+	transfer1.BytesReceived.Store(120)
 	// we are now overquota
 	// if another check is in progress nothing is done
-	atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
+	Connections.transfersCheckStatus.Store(true)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
 	assert.Nil(t, transfer2.errAbort)
 	assert.Nil(t, transfer3.errAbort)
-	atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
+	Connections.transfersCheckStatus.Store(false)
 
 	Connections.checkTransfers()
 	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)
@@ -172,8 +171,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	assert.Nil(t, transfer2.errAbort)
 	assert.Nil(t, transfer3.errAbort)
 	// now check a public folder
-	transfer1.BytesReceived = 0
-	transfer2.BytesReceived = 0
+	transfer1.BytesReceived.Store(0)
+	transfer2.BytesReceived.Store(0)
 	connID4 := xid.New().String()
 	fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
 	assert.NoError(t, err)
@@ -197,12 +196,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 
 	err = Connections.Add(fakeConn5)
 	assert.NoError(t, err)
-	transfer4.BytesReceived = 50
-	transfer5.BytesReceived = 40
+	transfer4.BytesReceived.Store(50)
+	transfer5.BytesReceived.Store(40)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer4.errAbort)
 	assert.Nil(t, transfer5.errAbort)
-	transfer5.BytesReceived = 60
+	transfer5.BytesReceived.Store(60)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
 	assert.Nil(t, transfer2.errAbort)
@@ -286,7 +285,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	}
 	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
-	transfer1.BytesReceived = 150
+	transfer1.BytesReceived.Store(150)
 	err = Connections.Add(fakeConn1)
 	assert.NoError(t, err)
 	// the transferschecker will do nothing if there is only one ongoing transfer
@@ -300,26 +299,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	}
 	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
 		"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
-	transfer2.BytesReceived = 150
+	transfer2.BytesReceived.Store(150)
 	err = Connections.Add(fakeConn2)
 	assert.NoError(t, err)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
 	assert.Nil(t, transfer2.errAbort)
 	// now test overquota
-	transfer1.BytesReceived = 1024*1024 + 1
-	transfer2.BytesReceived = 0
+	transfer1.BytesReceived.Store(1024*1024 + 1)
+	transfer2.BytesReceived.Store(0)
 	Connections.checkTransfers()
 	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
 	assert.Nil(t, transfer2.errAbort)
 	transfer1.errAbort = nil
-	transfer1.BytesReceived = 1024*1024 + 1
-	transfer2.BytesReceived = 1024
+	transfer1.BytesReceived.Store(1024*1024 + 1)
+	transfer2.BytesReceived.Store(1024)
 	Connections.checkTransfers()
 	assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
 	assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
-	transfer1.BytesReceived = 0
-	transfer2.BytesReceived = 0
+	transfer1.BytesReceived.Store(0)
+	transfer2.BytesReceived.Store(0)
 	transfer1.errAbort = nil
 	transfer2.errAbort = nil
 
@@ -337,7 +336,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	}
 	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
-	transfer3.BytesSent = 150
+	transfer3.BytesSent.Store(150)
 	err = Connections.Add(fakeConn3)
 	assert.NoError(t, err)
 
@@ -348,15 +347,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	}
 	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
 		"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
-	transfer4.BytesSent = 150
+	transfer4.BytesSent.Store(150)
 	err = Connections.Add(fakeConn4)
 	assert.NoError(t, err)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer3.errAbort)
 	assert.Nil(t, transfer4.errAbort)
 
-	transfer3.BytesSent = 512 * 1024
-	transfer4.BytesSent = 512*1024 + 1
+	transfer3.BytesSent.Store(512 * 1024)
+	transfer4.BytesSent.Store(512*1024 + 1)
 	Connections.checkTransfers()
 	if assert.Error(t, transfer3.errAbort) {
 		assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())

+ 4 - 4
internal/dataprovider/dataprovider.go

@@ -155,7 +155,7 @@ var (
 	ErrInvalidCredentials = errors.New("invalid credentials")
 	// ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP
 	ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP")
-	isAdminCreated           = int32(0)
+	isAdminCreated           atomic.Bool
 	validTLSUsernames        = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
 	config                   Config
 	provider                 Provider
@@ -844,7 +844,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
 	if err != nil {
 		return err
 	}
-	atomic.StoreInt32(&isAdminCreated, int32(len(admins)))
+	isAdminCreated.Store(len(admins) > 0)
 	delayedQuotaUpdater.start()
 	return startScheduler()
 }
@@ -1722,7 +1722,7 @@ func UpdateTaskTimestamp(name string) error {
 // HasAdmin returns true if the first admin has been created
 // and so SFTPGo is ready to be used
 func HasAdmin() bool {
-	return atomic.LoadInt32(&isAdminCreated) > 0
+	return isAdminCreated.Load()
 }
 
 // AddAdmin adds a new SFTPGo admin
@@ -1734,7 +1734,7 @@ func AddAdmin(admin *Admin, executor, ipAddress string) error {
 	admin.Username = config.convertName(admin.Username)
 	err := provider.addAdmin(admin)
 	if err == nil {
-		atomic.StoreInt32(&isAdminCreated, 1)
+		isAdminCreated.Store(true)
 		executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, admin)
 	}
 	return err

+ 12 - 12
internal/dataprovider/scheduler.go

@@ -28,11 +28,11 @@ import (
 
 var (
 	scheduler           *cron.Cron
-	lastUserCacheUpdate int64
+	lastUserCacheUpdate atomic.Int64
 	// used for bolt and memory providers, so we avoid iterating all users/rules
 	// to find recently modified ones
-	lastUserUpdate int64
-	lastRuleUpdate int64
+	lastUserUpdate atomic.Int64
+	lastRuleUpdate atomic.Int64
 )
 
 func stopScheduler() {
@@ -62,7 +62,7 @@ func startScheduler() error {
 }
 
 func addScheduledCacheUpdates() error {
-	lastUserCacheUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
+	lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
 	_, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
 	if err != nil {
 		return fmt.Errorf("unable to schedule cache updates: %w", err)
@@ -79,9 +79,9 @@ func checkDataprovider() {
 }
 
 func checkCacheUpdates() {
-	providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
+	providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
 	checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
-	users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate)
+	users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate.Load())
 	if err != nil {
 		providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
 		return
@@ -102,22 +102,22 @@ func checkCacheUpdates() {
 		cachedPasswords.Remove(user.Username)
 	}
 
-	lastUserCacheUpdate = checkTime
-	providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
+	lastUserCacheUpdate.Store(checkTime)
+	providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
 }
 
 func setLastUserUpdate() {
-	atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
+	lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
 }
 
 func getLastUserUpdate() int64 {
-	return atomic.LoadInt64(&lastUserUpdate)
+	return lastUserUpdate.Load()
 }
 
 func setLastRuleUpdate() {
-	atomic.StoreInt64(&lastRuleUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
+	lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
 }
 
 func getLastRuleUpdate() int64 {
-	return atomic.LoadInt64(&lastRuleUpdate)
+	return lastRuleUpdate.Load()
 }

+ 2 - 3
internal/ftpd/transfer.go

@@ -17,7 +17,6 @@ package ftpd
 import (
 	"errors"
 	"io"
-	"sync/atomic"
 
 	"github.com/eikenb/pipeat"
 
@@ -61,7 +60,7 @@ func (t *transfer) Read(p []byte) (n int, err error) {
 	t.Connection.UpdateLastActivity()
 
 	n, err = t.reader.Read(p)
-	atomic.AddInt64(&t.BytesSent, int64(n))
+	t.BytesSent.Add(int64(n))
 
 	if err == nil {
 		err = t.CheckRead()
@@ -79,7 +78,7 @@ func (t *transfer) Write(p []byte) (n int, err error) {
 	t.Connection.UpdateLastActivity()
 
 	n, err = t.writer.Write(p)
-	atomic.AddInt64(&t.BytesReceived, int64(n))
+	t.BytesReceived.Add(int64(n))
 
 	if err == nil {
 		err = t.CheckWrite()

+ 4 - 5
internal/httpd/file.go

@@ -16,7 +16,6 @@ package httpd
 
 import (
 	"io"
-	"sync/atomic"
 
 	"github.com/eikenb/pipeat"
 
@@ -52,7 +51,7 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
 
 // Read reads the contents to downloads.
 func (f *httpdFile) Read(p []byte) (n int, err error) {
-	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+	if f.AbortTransfer.Load() {
 		err := f.GetAbortError()
 		f.TransferError(err)
 		return 0, err
@@ -61,7 +60,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
 	f.Connection.UpdateLastActivity()
 
 	n, err = f.reader.Read(p)
-	atomic.AddInt64(&f.BytesSent, int64(n))
+	f.BytesSent.Add(int64(n))
 
 	if err == nil {
 		err = f.CheckRead()
@@ -76,7 +75,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
 
 // Write writes the contents to upload
 func (f *httpdFile) Write(p []byte) (n int, err error) {
-	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+	if f.AbortTransfer.Load() {
 		err := f.GetAbortError()
 		f.TransferError(err)
 		return 0, err
@@ -85,7 +84,7 @@ func (f *httpdFile) Write(p []byte) (n int, err error) {
 	f.Connection.UpdateLastActivity()
 
 	n, err = f.writer.Write(p)
-	atomic.AddInt64(&f.BytesReceived, int64(n))
+	f.BytesReceived.Add(int64(n))
 
 	if err == nil {
 		err = f.CheckWrite()

+ 15 - 15
internal/httpd/handler.go

@@ -238,24 +238,24 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
 
 func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader {
 	t := &throttledReader{
-		bytesRead:     0,
-		id:            conn.GetTransferID(),
-		limit:         limit,
-		r:             r,
-		abortTransfer: 0,
-		start:         time.Now(),
-		conn:          conn,
+		id:    conn.GetTransferID(),
+		limit: limit,
+		r:     r,
+		start: time.Now(),
+		conn:  conn,
 	}
+	t.bytesRead.Store(0)
+	t.abortTransfer.Store(false)
 	conn.AddTransfer(t)
 	return t
 }
 
 type throttledReader struct {
-	bytesRead     int64
+	bytesRead     atomic.Int64
 	id            int64
 	limit         int64
 	r             io.ReadCloser
-	abortTransfer int32
+	abortTransfer atomic.Bool
 	start         time.Time
 	conn          *Connection
 	mu            sync.Mutex
@@ -271,7 +271,7 @@ func (t *throttledReader) GetType() int {
 }
 
 func (t *throttledReader) GetSize() int64 {
-	return atomic.LoadInt64(&t.bytesRead)
+	return t.bytesRead.Load()
 }
 
 func (t *throttledReader) GetDownloadedSize() int64 {
@@ -279,7 +279,7 @@ func (t *throttledReader) GetDownloadedSize() int64 {
 }
 
 func (t *throttledReader) GetUploadedSize() int64 {
-	return atomic.LoadInt64(&t.bytesRead)
+	return t.bytesRead.Load()
 }
 
 func (t *throttledReader) GetVirtualPath() string {
@@ -304,7 +304,7 @@ func (t *throttledReader) SignalClose(err error) {
 	t.mu.Lock()
 	t.errAbort = err
 	t.mu.Unlock()
-	atomic.StoreInt32(&(t.abortTransfer), 1)
+	t.abortTransfer.Store(true)
 }
 
 func (t *throttledReader) GetTruncatedSize() int64 {
@@ -328,15 +328,15 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
 }
 
 func (t *throttledReader) Read(p []byte) (n int, err error) {
-	if atomic.LoadInt32(&t.abortTransfer) == 1 {
+	if t.abortTransfer.Load() {
 		return 0, t.GetAbortError()
 	}
 
 	t.conn.UpdateLastActivity()
 	n, err = t.r.Read(p)
 	if t.limit > 0 {
-		atomic.AddInt64(&t.bytesRead, int64(n))
-		trasferredBytes := atomic.LoadInt64(&t.bytesRead)
+		t.bytesRead.Add(int64(n))
+		trasferredBytes := t.bytesRead.Load()
 		elapsed := time.Since(t.start).Nanoseconds() / 1000000
 		wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit
 		if wantedElapsed > elapsed {

+ 9 - 9
internal/plugin/plugin.go

@@ -93,7 +93,7 @@ func (c *Config) newKMSPluginSecretProvider(base kms.BaseSecret, url, masterKey
 
 // Manager handles enabled plugins
 type Manager struct {
-	closed int32
+	closed atomic.Bool
 	done   chan bool
 	// List of configured plugins
 	Configs          []Config `json:"plugins" mapstructure:"plugins"`
@@ -124,10 +124,10 @@ func Initialize(configs []Config, logLevel string) error {
 	Handler = Manager{
 		Configs:          configs,
 		done:             make(chan bool),
-		closed:           0,
 		authScopes:       -1,
 		concurrencyGuard: make(chan struct{}, 250),
 	}
+	Handler.closed.Store(false)
 	setLogLevel(logLevel)
 	if len(configs) == 0 {
 		return nil
@@ -604,7 +604,7 @@ func (m *Manager) checkCrashedPlugins() {
 }
 
 func (m *Manager) restartNotifierPlugin(config Config, idx int) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed notifier plugin %#v, idx: %v", config.Cmd, idx)
@@ -622,7 +622,7 @@ func (m *Manager) restartNotifierPlugin(config Config, idx int) {
 }
 
 func (m *Manager) restartKMSPlugin(config Config, idx int) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed kms plugin %#v, idx: %v", config.Cmd, idx)
@@ -638,7 +638,7 @@ func (m *Manager) restartKMSPlugin(config Config, idx int) {
 }
 
 func (m *Manager) restartAuthPlugin(config Config, idx int) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed auth plugin %#v, idx: %v", config.Cmd, idx)
@@ -654,7 +654,7 @@ func (m *Manager) restartAuthPlugin(config Config, idx int) {
 }
 
 func (m *Manager) restartSearcherPlugin(config Config) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed searcher plugin %#v", config.Cmd)
@@ -670,7 +670,7 @@ func (m *Manager) restartSearcherPlugin(config Config) {
 }
 
 func (m *Manager) restartMetadaterPlugin(config Config) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed metadater plugin %#v", config.Cmd)
@@ -686,7 +686,7 @@ func (m *Manager) restartMetadaterPlugin(config Config) {
 }
 
 func (m *Manager) restartIPFilterPlugin(config Config) {
-	if atomic.LoadInt32(&m.closed) == 1 {
+	if m.closed.Load() {
 		return
 	}
 	logger.Info(logSender, "", "try to restart crashed IP filter plugin %#v", config.Cmd)
@@ -712,7 +712,7 @@ func (m *Manager) removeTask() {
 // Cleanup releases all the active plugins
 func (m *Manager) Cleanup() {
 	logger.Debug(logSender, "", "cleanup")
-	atomic.StoreInt32(&m.closed, 1)
+	m.closed.Store(true)
 	close(m.done)
 	m.notifLock.Lock()
 	for _, n := range m.notifiers {

+ 1 - 1
internal/sftpd/internal_test.go

@@ -1785,7 +1785,7 @@ func TestUploadError(t *testing.T) {
 	if assert.Error(t, transfer.ErrTransfer) {
 		assert.EqualError(t, transfer.ErrTransfer, errFake.Error())
 	}
-	assert.Equal(t, int64(0), transfer.BytesReceived)
+	assert.Equal(t, int64(0), transfer.BytesReceived.Load())
 
 	assert.NoFileExists(t, testfile)
 	assert.NoFileExists(t, fileTempName)

+ 3 - 3
internal/sftpd/sftpd_test.go

@@ -1114,12 +1114,12 @@ func TestConcurrency(t *testing.T) {
 	err = createTestFile(testFilePath, testFileSize)
 	assert.NoError(t, err)
 
-	closedConns := int32(0)
+	var closedConns atomic.Int32
 	for i := 0; i < numLogins; i++ {
 		wg.Add(1)
 		go func(counter int) {
 			defer wg.Done()
-			defer atomic.AddInt32(&closedConns, 1)
+			defer closedConns.Add(1)
 
 			conn, client, err := getSftpClient(user, usePubKey)
 			if assert.NoError(t, err) {
@@ -1139,7 +1139,7 @@ func TestConcurrency(t *testing.T) {
 		maxConns := 0
 		maxSessions := 0
 		for {
-			servedReqs := atomic.LoadInt32(&closedConns)
+			servedReqs := closedConns.Load()
 			if servedReqs > 0 {
 				stats := common.Connections.GetStats()
 				if len(stats) > maxConns {

+ 5 - 6
internal/sftpd/transfer.go

@@ -17,7 +17,6 @@ package sftpd
 import (
 	"fmt"
 	"io"
-	"sync/atomic"
 
 	"github.com/eikenb/pipeat"
 
@@ -107,7 +106,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) {
 	t.Connection.UpdateLastActivity()
 
 	n, err = t.readerAt.ReadAt(p, off)
-	atomic.AddInt64(&t.BytesSent, int64(n))
+	t.BytesSent.Add(int64(n))
 
 	if err == nil {
 		err = t.CheckRead()
@@ -133,7 +132,7 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
 	}
 
 	n, err = t.writerAt.WriteAt(p, off)
-	atomic.AddInt64(&t.BytesReceived, int64(n))
+	t.BytesReceived.Add(int64(n))
 
 	if err == nil {
 		err = t.CheckWrite()
@@ -213,13 +212,13 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
 			if nw > 0 {
 				written += int64(nw)
 				if isDownload {
-					atomic.StoreInt64(&t.BytesSent, written)
+					t.BytesSent.Store(written)
 					if errCheck := t.CheckRead(); errCheck != nil {
 						err = errCheck
 						break
 					}
 				} else {
-					atomic.StoreInt64(&t.BytesReceived, written)
+					t.BytesReceived.Store(written)
 					if errCheck := t.CheckWrite(); errCheck != nil {
 						err = errCheck
 						break
@@ -245,7 +244,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
 	}
 	t.ErrTransfer = err
 	if written > 0 || err != nil {
-		metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.GetType(),
+		metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.GetType(),
 			t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
 	}
 	return written, err

+ 15 - 15
internal/util/timeoutlistener.go

@@ -32,14 +32,14 @@ func (l *listener) Accept() (net.Conn, error) {
 		return nil, err
 	}
 	tc := &Conn{
-		Conn:                     c,
-		ReadTimeout:              l.ReadTimeout,
-		WriteTimeout:             l.WriteTimeout,
-		ReadThreshold:            int32((l.ReadTimeout * 1024) / time.Second),
-		WriteThreshold:           int32((l.WriteTimeout * 1024) / time.Second),
-		BytesReadFromDeadline:    0,
-		BytesWrittenFromDeadline: 0,
+		Conn:           c,
+		ReadTimeout:    l.ReadTimeout,
+		WriteTimeout:   l.WriteTimeout,
+		ReadThreshold:  int32((l.ReadTimeout * 1024) / time.Second),
+		WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
 	}
+	tc.BytesReadFromDeadline.Store(0)
+	tc.BytesWrittenFromDeadline.Store(0)
 	return tc, nil
 }
 
@@ -51,13 +51,13 @@ type Conn struct {
 	WriteTimeout             time.Duration
 	ReadThreshold            int32
 	WriteThreshold           int32
-	BytesReadFromDeadline    int32
-	BytesWrittenFromDeadline int32
+	BytesReadFromDeadline    atomic.Int32
+	BytesWrittenFromDeadline atomic.Int32
 }
 
 func (c *Conn) Read(b []byte) (n int, err error) {
-	if atomic.LoadInt32(&c.BytesReadFromDeadline) > c.ReadThreshold {
-		atomic.StoreInt32(&c.BytesReadFromDeadline, 0)
+	if c.BytesReadFromDeadline.Load() > c.ReadThreshold {
+		c.BytesReadFromDeadline.Store(0)
 		// we set both read and write deadlines here otherwise after the request
 		// is read writing the response fails with an i/o timeout error
 		err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout))
@@ -66,13 +66,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
 		}
 	}
 	n, err = c.Conn.Read(b)
-	atomic.AddInt32(&c.BytesReadFromDeadline, int32(n))
+	c.BytesReadFromDeadline.Add(int32(n))
 	return
 }
 
 func (c *Conn) Write(b []byte) (n int, err error) {
-	if atomic.LoadInt32(&c.BytesWrittenFromDeadline) > c.WriteThreshold {
-		atomic.StoreInt32(&c.BytesWrittenFromDeadline, 0)
+	if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold {
+		c.BytesWrittenFromDeadline.Store(0)
 		// we extend the read deadline too, not sure it's necessary,
 		// but it doesn't hurt
 		err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout))
@@ -81,7 +81,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
 		}
 	}
 	n, err = c.Conn.Write(b)
-	atomic.AddInt32(&c.BytesWrittenFromDeadline, int32(n))
+	c.BytesWrittenFromDeadline.Add(int32(n))
 	return
 }
 

+ 6 - 6
internal/vfs/azblobfs.go

@@ -902,7 +902,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
 	finished := false
 	var wg sync.WaitGroup
 	var errOnce sync.Once
-	var hasError int32
+	var hasError atomic.Bool
 	var poolError error
 
 	poolCtx, poolCancel := context.WithCancel(ctx)
@@ -919,7 +919,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
 		offset = end
 
 		guard <- struct{}{}
-		if atomic.LoadInt32(&hasError) == 1 {
+		if hasError.Load() {
 			fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part)
 			break
 		}
@@ -941,7 +941,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
 			if err != nil {
 				errOnce.Do(func() {
 					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
-					atomic.StoreInt32(&hasError, 1)
+					hasError.Store(true)
 					poolError = fmt.Errorf("multipart download error: %w", err)
 					poolCancel()
 				})
@@ -971,7 +971,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	var blocks []string
 	var wg sync.WaitGroup
 	var errOnce sync.Once
-	var hasError int32
+	var hasError atomic.Bool
 	var poolError error
 
 	poolCtx, poolCancel := context.WithCancel(ctx)
@@ -999,7 +999,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 		blocks = append(blocks, blockID)
 
 		guard <- struct{}{}
-		if atomic.LoadInt32(&hasError) == 1 {
+		if hasError.Load() {
 			fsLog(fs, logger.LevelError, "pool error, upload for part %v not started", part)
 			pool.releaseBuffer(buf)
 			break
@@ -1023,7 +1023,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 			if err != nil {
 				errOnce.Do(func() {
 					fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err)
-					atomic.StoreInt32(&hasError, 1)
+					hasError.Store(true)
 					poolError = fmt.Errorf("multipart upload error: %w", err)
 					poolCancel()
 				})

+ 3 - 3
internal/vfs/s3fs.go

@@ -835,7 +835,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 	var completedParts []types.CompletedPart
 	var partMutex sync.Mutex
 	var wg sync.WaitGroup
-	var hasError int32
+	var hasError atomic.Bool
 	var errOnce sync.Once
 	var copyError error
 	var partNumber int32
@@ -854,7 +854,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 		offset = end
 
 		guard <- struct{}{}
-		if atomic.LoadInt32(&hasError) == 1 {
+		if hasError.Load() {
 			fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber)
 			break
 		}
@@ -880,7 +880,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 			if err != nil {
 				errOnce.Do(func() {
 					fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err)
-					atomic.StoreInt32(&hasError, 1)
+					hasError.Store(true)
 					copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
 					opCancel()
 

+ 16 - 15
internal/webdavd/file.go

@@ -42,7 +42,7 @@ type webDavFile struct {
 	info        os.FileInfo
 	startOffset int64
 	isFinished  bool
-	readTryed   int32
+	readTryed   atomic.Bool
 }
 
 func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *webDavFile {
@@ -56,15 +56,16 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter
 	} else if pipeReader != nil {
 		reader = pipeReader
 	}
-	return &webDavFile{
+	f := &webDavFile{
 		BaseTransfer: baseTransfer,
 		writer:       writer,
 		reader:       reader,
 		isFinished:   false,
 		startOffset:  0,
 		info:         nil,
-		readTryed:    0,
 	}
+	f.readTryed.Store(false)
+	return f
 }
 
 type webDavFileInfo struct {
@@ -124,7 +125,7 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {
 	f.Unlock()
 	if f.GetType() == common.TransferUpload && errUpload == nil {
 		info := &webDavFileInfo{
-			FileInfo:    vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Unix(0, 0), false),
+			FileInfo:    vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Unix(0, 0), false),
 			Fs:          f.Fs,
 			virtualPath: f.GetVirtualPath(),
 			fsPath:      f.GetFsPath(),
@@ -149,10 +150,10 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {
 
 // Read reads the contents to downloads.
 func (f *webDavFile) Read(p []byte) (n int, err error) {
-	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+	if f.AbortTransfer.Load() {
 		return 0, errTransferAborted
 	}
-	if atomic.LoadInt32(&f.readTryed) == 0 {
+	if !f.readTryed.Load() {
 		if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) {
 			return 0, f.Connection.GetPermissionDeniedError()
 		}
@@ -171,7 +172,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
 			f.Connection.Log(logger.LevelDebug, "download for file %#v denied by pre action: %v", f.GetVirtualPath(), err)
 			return 0, f.Connection.GetPermissionDeniedError()
 		}
-		atomic.StoreInt32(&f.readTryed, 1)
+		f.readTryed.Store(true)
 	}
 
 	f.Connection.UpdateLastActivity()
@@ -198,7 +199,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
 	}
 
 	n, err = f.reader.Read(p)
-	atomic.AddInt64(&f.BytesSent, int64(n))
+	f.BytesSent.Add(int64(n))
 	if err == nil {
 		err = f.CheckRead()
 	}
@@ -212,14 +213,14 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
 
 // Write writes the uploaded contents.
 func (f *webDavFile) Write(p []byte) (n int, err error) {
-	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
+	if f.AbortTransfer.Load() {
 		return 0, errTransferAborted
 	}
 
 	f.Connection.UpdateLastActivity()
 
 	n, err = f.writer.Write(p)
-	atomic.AddInt64(&f.BytesReceived, int64(n))
+	f.BytesReceived.Add(int64(n))
 
 	if err == nil {
 		err = f.CheckWrite()
@@ -252,7 +253,7 @@ func (f *webDavFile) updateTransferQuotaOnSeek() {
 	if transferQuota.HasSizeLimits() {
 		go func(ulSize, dlSize int64, user dataprovider.User) {
 			dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
-		}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
+		}(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User)
 	}
 }
 
@@ -270,7 +271,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
 		return ret, err
 	}
 	if f.GetType() == common.TransferDownload {
-		readOffset := f.startOffset + atomic.LoadInt64(&f.BytesSent)
+		readOffset := f.startOffset + f.BytesSent.Load()
 		if offset == 0 && readOffset == 0 {
 			if whence == io.SeekStart {
 				return 0, nil
@@ -288,8 +289,8 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
 			f.reader = nil
 		}
 		startByte := int64(0)
-		atomic.StoreInt64(&f.BytesReceived, 0)
-		atomic.StoreInt64(&f.BytesSent, 0)
+		f.BytesReceived.Store(0)
+		f.BytesSent.Store(0)
 		f.updateTransferQuotaOnSeek()
 
 		switch whence {
@@ -369,7 +370,7 @@ func (f *webDavFile) setFinished() error {
 
 func (f *webDavFile) isTransfer() bool {
 	if f.GetType() == common.TransferDownload {
-		return atomic.LoadInt32(&f.readTryed) > 0
+		return f.readTryed.Load()
 	}
 	return true
 }