use the new atomic types introduced in Go 1.19

we depend on Go 1.19 anyway

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-08-30 15:47:41 +02:00
parent da03f6c4e3
commit 95e9106902
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
22 changed files with 231 additions and 231 deletions

View file

@ -23,13 +23,13 @@ import (
// clienstMap is a struct containing the map of the connected clients
type clientsMap struct {
totalConnections int32
totalConnections atomic.Int32
mu sync.RWMutex
clients map[string]int
}
func (c *clientsMap) add(source string) {
atomic.AddInt32(&c.totalConnections, 1)
c.totalConnections.Add(1)
c.mu.Lock()
defer c.mu.Unlock()
@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) {
defer c.mu.Unlock()
if val, ok := c.clients[source]; ok {
atomic.AddInt32(&c.totalConnections, -1)
c.totalConnections.Add(-1)
c.clients[source]--
if val > 1 {
return
@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) {
}
func (c *clientsMap) getTotal() int32 {
return atomic.LoadInt32(&c.totalConnections)
return c.totalConnections.Load()
}
func (c *clientsMap) getTotalFrom(source string) int {

View file

@ -704,16 +704,17 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
type SSHConnection struct {
id string
conn net.Conn
lastActivity int64
lastActivity atomic.Int64
}
// NewSSHConnection returns a new SSHConnection
func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
return &SSHConnection{
id: id,
conn: conn,
lastActivity: time.Now().UnixNano(),
c := &SSHConnection{
id: id,
conn: conn,
}
c.lastActivity.Store(time.Now().UnixNano())
return c
}
// GetID returns the ID for this SSHConnection
@ -723,12 +724,12 @@ func (c *SSHConnection) GetID() string {
// UpdateLastActivity updates last activity for this connection
func (c *SSHConnection) UpdateLastActivity() {
atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
c.lastActivity.Store(time.Now().UnixNano())
}
// GetLastActivity returns the last connection activity
func (c *SSHConnection) GetLastActivity() time.Time {
return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
return time.Unix(0, c.lastActivity.Load())
}
// Close closes the underlying network connection
@ -741,7 +742,7 @@ type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
transfersCheckStatus int32
transfersCheckStatus atomic.Bool
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
@ -953,12 +954,12 @@ func (conns *ActiveConnections) checkIdles() {
}
func (conns *ActiveConnections) checkTransfers() {
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
if conns.transfersCheckStatus.Load() {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return
}
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
conns.transfersCheckStatus.Store(true)
defer conns.transfersCheckStatus.Store(false)
conns.RLock()

View file

@ -24,7 +24,6 @@ import (
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"time"
@ -498,14 +497,14 @@ func TestIdleConnections(t *testing.T) {
},
}
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
fakeConn := &fakeConnection{
BaseConnection: c,
}
// both ssh connections are expired but they should get removed only
// if there is no associated connection
sshConn1.lastActivity = c.lastActivity
sshConn2.lastActivity = c.lastActivity
sshConn1.lastActivity.Store(c.lastActivity.Load())
sshConn2.lastActivity.Store(c.lastActivity.Load())
Connections.AddSSHConnection(sshConn1)
err = Connections.Add(fakeConn)
assert.NoError(t, err)
@ -520,7 +519,7 @@ func TestIdleConnections(t *testing.T) {
assert.Equal(t, Connections.GetActiveSessions(username), 2)
cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
cFTP.lastActivity = time.Now().UnixNano()
cFTP.lastActivity.Store(time.Now().UnixNano())
fakeConn = &fakeConnection{
BaseConnection: cFTP,
}
@ -541,9 +540,9 @@ func TestIdleConnections(t *testing.T) {
}, 1*time.Second, 200*time.Millisecond)
stopEventScheduler()
assert.Len(t, Connections.GetStats(), 2)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity = c.lastActivity
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
sshConn2.lastActivity.Store(c.lastActivity.Load())
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
@ -646,9 +645,9 @@ func TestConnectionStatus(t *testing.T) {
BaseConnection: c1,
}
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
t1.BytesReceived = 123
t1.BytesReceived.Store(123)
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
t2.BytesSent = 456
t2.BytesSent.Store(456)
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
fakeConn2 := &fakeConnection{
BaseConnection: c2,
@ -698,7 +697,7 @@ func TestConnectionStatus(t *testing.T) {
err = fakeConn3.SignalTransfersAbort()
assert.NoError(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
assert.True(t, t3.AbortTransfer.Load())
err = t3.Close()
assert.NoError(t, err)
err = fakeConn3.SignalTransfersAbort()

View file

@ -38,12 +38,12 @@ import (
type BaseConnection struct {
// last activity for this connection.
// Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
lastActivity int64
lastActivity atomic.Int64
uploadDone atomic.Bool
downloadDone atomic.Bool
// unique ID for a transfer.
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
transferID int64
transferID atomic.Int64
// Unique identifier for the connection
ID string
// user associated with this connection if any
@ -64,16 +64,18 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov
connID = fmt.Sprintf("%s_%s", protocol, id)
}
user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
return &BaseConnection{
ID: connID,
User: user,
startTime: time.Now(),
protocol: protocol,
localAddr: localAddr,
remoteAddr: remoteAddr,
lastActivity: time.Now().UnixNano(),
transferID: 0,
c := &BaseConnection{
ID: connID,
User: user,
startTime: time.Now(),
protocol: protocol,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
c.transferID.Store(0)
c.lastActivity.Store(time.Now().UnixNano())
return c
}
// Log outputs a log entry to the configured logger
@ -83,7 +85,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {
// GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() int64 {
return atomic.AddInt64(&c.transferID, 1)
return c.transferID.Add(1)
}
// GetID returns the connection ID
@ -126,12 +128,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time {
// UpdateLastActivity updates last activity for this connection
func (c *BaseConnection) UpdateLastActivity() {
atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
c.lastActivity.Store(time.Now().UnixNano())
}
// GetLastActivity returns the last connection activity
func (c *BaseConnection) GetLastActivity() time.Time {
return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
return time.Unix(0, c.lastActivity.Load())
}
// CloseFS closes the underlying fs

View file

@ -82,7 +82,7 @@ func HandleCertificateEvent(params EventParams) {
// eventRulesContainer stores event rules by trigger
type eventRulesContainer struct {
sync.RWMutex
lastLoad int64
lastLoad atomic.Int64
FsEvents []dataprovider.EventRule
ProviderEvents []dataprovider.EventRule
Schedules []dataprovider.EventRule
@ -101,11 +101,11 @@ func (r *eventRulesContainer) removeAsyncTask() {
}
func (r *eventRulesContainer) getLastLoadTime() int64 {
return atomic.LoadInt64(&r.lastLoad)
return r.lastLoad.Load()
}
func (r *eventRulesContainer) setLastLoadTime(modTime int64) {
atomic.StoreInt64(&r.lastLoad, modTime)
r.lastLoad.Store(modTime)
}
// RemoveRule deletes the rule with the specified name

View file

@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
}
type sourceRateLimiter struct {
lastActivity int64
lastActivity *atomic.Int64
bucket *rate.Limiter
}
func (s *sourceRateLimiter) updateLastActivity() {
atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
s.lastActivity.Store(time.Now().UnixNano())
}
func (s *sourceRateLimiter) getLastActivity() int64 {
return atomic.LoadInt64(&s.lastActivity)
return s.lastActivity.Load()
}
type sourceBuckets struct {
@ -224,7 +224,8 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese
b.cleanup()
src := sourceRateLimiter{
bucket: r,
lastActivity: new(atomic.Int64),
bucket: r,
}
src.updateLastActivity()
b.buckets[source] = src

View file

@ -35,8 +35,8 @@ var (
// BaseTransfer contains protocols common transfer details for an upload or a download.
type BaseTransfer struct { //nolint:maligned
ID int64
BytesSent int64
BytesReceived int64
BytesSent atomic.Int64
BytesReceived atomic.Int64
Fs vfs.Fs
File vfs.File
Connection *BaseConnection
@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned
truncatedSize int64
isNewFile bool
transferType int
AbortTransfer int32
AbortTransfer atomic.Bool
aTime time.Time
mTime time.Time
transferQuota dataprovider.TransferQuota
@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
InitialSize: initialSize,
isNewFile: isNewFile,
requestPath: requestPath,
BytesSent: 0,
BytesReceived: 0,
MaxWriteSize: maxWriteSize,
AbortTransfer: 0,
truncatedSize: truncatedSize,
transferQuota: transferQuota,
Fs: fs,
}
t.AbortTransfer.Store(false)
t.BytesSent.Store(0)
t.BytesReceived.Store(0)
conn.AddTransfer(t)
return t
@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int {
// GetSize returns the transferred size
func (t *BaseTransfer) GetSize() int64 {
if t.transferType == TransferDownload {
return atomic.LoadInt64(&t.BytesSent)
return t.BytesSent.Load()
}
return atomic.LoadInt64(&t.BytesReceived)
return t.BytesReceived.Load()
}
// GetDownloadedSize returns the transferred size
func (t *BaseTransfer) GetDownloadedSize() int64 {
return atomic.LoadInt64(&t.BytesSent)
return t.BytesSent.Load()
}
// GetUploadedSize returns the transferred size
func (t *BaseTransfer) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.BytesReceived)
return t.BytesReceived.Load()
}
// GetStartTime returns the start time
@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) {
t.Lock()
t.errAbort = err
t.Unlock()
atomic.StoreInt32(&(t.AbortTransfer), 1)
t.AbortTransfer.Store(true)
}
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error {
return nil
}
if t.transferQuota.AllowedTotalSize > 0 {
if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
return t.Connection.GetReadQuotaExceededError()
}
} else if t.transferQuota.AllowedDLSize > 0 {
if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize {
if t.BytesSent.Load() > t.transferQuota.AllowedDLSize {
return t.Connection.GetReadQuotaExceededError()
}
}
@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error {
// CheckWrite returns an error if write if not allowed
func (t *BaseTransfer) CheckWrite() error {
if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize {
if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize {
return t.Connection.GetQuotaExceededError()
}
if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
return nil
}
if t.transferQuota.AllowedTotalSize > 0 {
if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
return t.Connection.GetQuotaExceededError()
}
} else if t.transferQuota.AllowedULSize > 0 {
if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize {
if t.BytesReceived.Load() > t.transferQuota.AllowedULSize {
return t.Connection.GetQuotaExceededError()
}
}
@ -261,14 +261,14 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
if t.MaxWriteSize > 0 {
sizeDiff := initialSize - size
t.MaxWriteSize += sizeDiff
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
if t.transferQuota.HasSizeLimits() {
go func(ulSize, dlSize int64, user dataprovider.User) {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
}(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User)
}
atomic.StoreInt64(&t.BytesReceived, 0)
t.BytesReceived.Store(0)
}
t.Unlock()
}
@ -276,7 +276,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
fsPath, size, t.MaxWriteSize, t.InitialSize, err)
return initialSize, err
}
if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 {
if size == 0 && t.BytesSent.Load() == 0 {
// for cloud providers the file is always truncated to zero, we don't support append/resume for uploads
// for buffered SFTP we can have buffered bytes so we returns an error
if !vfs.IsBufferedSFTPFs(t.Fs) {
@ -302,8 +302,8 @@ func (t *BaseTransfer) TransferError(err error) {
}
elapsed := time.Since(t.start).Nanoseconds() / 1000000
t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent),
atomic.LoadInt64(&t.BytesReceived), elapsed)
"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
t.BytesReceived.Load(), elapsed)
}
func (t *BaseTransfer) getUploadFileSize() (int64, error) {
@ -333,7 +333,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v",
t.effectiveFsPath, err)
// the file is outside the home dir so don't update the quota
atomic.StoreInt64(&t.BytesReceived, 0)
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
return 1
}
@ -351,18 +351,18 @@ func (t *BaseTransfer) Close() error {
if t.isNewFile {
numFiles = 1
}
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
if t.transferQuota.HasSizeLimits() {
dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
atomic.LoadInt64(&t.BytesSent), false)
dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
t.BytesSent.Load(), false)
}
if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = t.Fs.Remove(t.File.Name(), false)
if err == nil {
numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
}
t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
@ -380,7 +380,7 @@ func (t *BaseTransfer) Close() error {
t.ErrTransfer, t.effectiveFsPath, err)
if err == nil {
numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
}
}
@ -388,12 +388,12 @@ func (t *BaseTransfer) Close() error {
elapsed := time.Since(t.start).Nanoseconds() / 1000000
var uploadFileSize int64
if t.transferType == TransferDownload {
logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username,
logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
t.BytesSent.Load(), t.ErrTransfer)
} else {
uploadFileSize = atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
if statSize, errStat := t.getUploadFileSize(); errStat == nil {
uploadFileSize = statSize
}
@ -401,7 +401,7 @@ func (t *BaseTransfer) Close() error {
numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize)
t.updateQuota(numFiles, uploadFileSize)
t.updateTimes()
logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username,
logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
}
if t.ErrTransfer != nil {
@ -428,11 +428,11 @@ func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) {
}
return
}
if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && atomic.LoadInt64(&t.BytesSent) > 0 {
if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
t.Connection.downloadDone.Store(true)
ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
"", "", atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
"", "", t.BytesSent.Load(), t.ErrTransfer)
}
}
}
@ -449,7 +449,7 @@ func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int
if err == nil {
numFiles--
fileSize = 0
atomic.StoreInt64(&t.BytesReceived, 0)
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
} else {
t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
@ -494,10 +494,10 @@ func (t *BaseTransfer) HandleThrottle() {
var trasferredBytes int64
if t.transferType == TransferDownload {
wantedBandwidth = t.Connection.User.DownloadBandwidth
trasferredBytes = atomic.LoadInt64(&t.BytesSent)
trasferredBytes = t.BytesSent.Load()
} else {
wantedBandwidth = t.Connection.User.UploadBandwidth
trasferredBytes = atomic.LoadInt64(&t.BytesReceived)
trasferredBytes = t.BytesReceived.Load()
}
if wantedBandwidth > 0 {
// real and wanted elapsed as milliseconds, bytes as kilobytes

View file

@ -33,11 +33,11 @@ import (
func TestTransferUpdateQuota(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
}
transfer.BytesReceived.Store(123)
errFake := errors.New("fake error")
transfer.TransferError(errFake)
assert.False(t, transfer.updateQuota(1, 0))
@ -56,7 +56,7 @@ func TestTransferUpdateQuota(t *testing.T) {
QuotaSize: -1,
})
transfer.ErrTransfer = nil
transfer.BytesReceived = 1
transfer.BytesReceived.Store(1)
transfer.requestPath = "/vdir/file"
assert.True(t, transfer.updateQuota(1, 0))
err = transfer.Close()
@ -80,7 +80,7 @@ func TestTransferThrottling(t *testing.T) {
wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesReceived = testFileSize
transfer.BytesReceived.Store(testFileSize)
transfer.Connection.UpdateLastActivity()
startTime := transfer.Connection.GetLastActivity()
transfer.HandleThrottle()
@ -90,7 +90,7 @@ func TestTransferThrottling(t *testing.T) {
assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesSent = testFileSize
transfer.BytesSent.Store(testFileSize)
transfer.Connection.UpdateLastActivity()
startTime = transfer.Connection.GetLastActivity()
@ -226,7 +226,7 @@ func TestTransferErrors(t *testing.T) {
assert.Equal(t, testFile, transfer.GetFsPath())
transfer.SetCancelFn(cancelFn)
errFake := errors.New("err fake")
transfer.BytesReceived = 9
transfer.BytesReceived.Store(9)
transfer.TransferError(ErrQuotaExceeded)
assert.True(t, isCancelled)
transfer.TransferError(errFake)
@ -249,7 +249,7 @@ func TestTransferErrors(t *testing.T) {
fsPath := filepath.Join(os.TempDir(), "test_file")
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{})
transfer.BytesReceived = 9
transfer.BytesReceived.Store(9)
transfer.TransferError(errFake)
assert.Error(t, transfer.ErrTransfer, errFake.Error())
// the file is closed from the embedding struct before to call close
@ -269,7 +269,7 @@ func TestTransferErrors(t *testing.T) {
}
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{})
transfer.BytesReceived = 9
transfer.BytesReceived.Store(9)
// the file is closed from the embedding struct before to call close
err = file.Close()
assert.NoError(t, err)
@ -310,11 +310,11 @@ func TestRemovePartialCryptoFile(t *testing.T) {
func TestFTPMode(t *testing.T) {
conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
}
transfer.BytesReceived.Store(123)
assert.Empty(t, transfer.ftpMode)
transfer.SetFtpMode("active")
assert.Equal(t, "active", transfer.ftpMode)
@ -399,14 +399,14 @@ func TestTransferQuota(t *testing.T) {
transfer.transferQuota = dataprovider.TransferQuota{
AllowedTotalSize: 10,
}
transfer.BytesReceived = 5
transfer.BytesSent = 4
transfer.BytesReceived.Store(5)
transfer.BytesSent.Store(4)
err = transfer.CheckRead()
assert.NoError(t, err)
err = transfer.CheckWrite()
assert.NoError(t, err)
transfer.BytesSent = 6
transfer.BytesSent.Store(6)
err = transfer.CheckRead()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@ -428,7 +428,7 @@ func TestTransferQuota(t *testing.T) {
err = transfer.CheckWrite()
assert.NoError(t, err)
transfer.BytesReceived = 11
transfer.BytesReceived.Store(11)
err = transfer.CheckRead()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@ -442,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
}
transfer.BytesReceived.Store(123)
fileName := filepath.Join(os.TempDir(), "_temp")
err := os.WriteFile(fileName, []byte(`data`), 0644)
@ -459,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
Config.TempPath = filepath.Clean(os.TempDir())
res = transfer.checkUploadOutsideHomeDir(nil)
assert.Equal(t, 0, res)
assert.Greater(t, transfer.BytesReceived, int64(0))
assert.Greater(t, transfer.BytesReceived.Load(), int64(0))
res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
assert.Equal(t, 1, res)
assert.Equal(t, int64(0), transfer.BytesReceived)
assert.Equal(t, int64(0), transfer.BytesReceived.Load())
assert.NoFileExists(t, fileName)
Config.TempPath = oldTempPath

View file

@ -21,7 +21,6 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -96,7 +95,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived = 150
transfer1.BytesReceived.Store(150)
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
@ -110,8 +109,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived = 50
transfer2.BytesReceived = 60
transfer1.BytesReceived.Store(50)
transfer2.BytesReceived.Store(60)
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
@ -122,7 +121,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
@ -132,20 +131,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived = 120
transfer1.BytesReceived.Store(120)
// we are now overquota
// if another check is in progress nothing is done
atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
Connections.transfersCheckStatus.Store(true)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
Connections.transfersCheckStatus.Store(false)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)
@ -172,8 +171,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
// now check a public folder
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
transfer1.BytesReceived.Store(0)
transfer2.BytesReceived.Store(0)
connID4 := xid.New().String()
fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
assert.NoError(t, err)
@ -197,12 +196,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
err = Connections.Add(fakeConn5)
assert.NoError(t, err)
transfer4.BytesReceived = 50
transfer5.BytesReceived = 40
transfer4.BytesReceived.Store(50)
transfer5.BytesReceived.Store(40)
Connections.checkTransfers()
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
transfer5.BytesReceived = 60
transfer5.BytesReceived.Store(60)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
@ -286,7 +285,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer1.BytesReceived = 150
transfer1.BytesReceived.Store(150)
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
@ -300,26 +299,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer2.BytesReceived = 150
transfer2.BytesReceived.Store(150)
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
// now test overquota
transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived = 0
transfer1.BytesReceived.Store(1024*1024 + 1)
transfer2.BytesReceived.Store(0)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.Nil(t, transfer2.errAbort)
transfer1.errAbort = nil
transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived = 1024
transfer1.BytesReceived.Store(1024*1024 + 1)
transfer2.BytesReceived.Store(1024)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
transfer1.BytesReceived.Store(0)
transfer2.BytesReceived.Store(0)
transfer1.errAbort = nil
transfer2.errAbort = nil
@ -337,7 +336,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer3.BytesSent = 150
transfer3.BytesSent.Store(150)
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
@ -348,15 +347,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer4.BytesSent = 150
transfer4.BytesSent.Store(150)
err = Connections.Add(fakeConn4)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort)
transfer3.BytesSent = 512 * 1024
transfer4.BytesSent = 512*1024 + 1
transfer3.BytesSent.Store(512 * 1024)
transfer4.BytesSent.Store(512*1024 + 1)
Connections.checkTransfers()
if assert.Error(t, transfer3.errAbort) {
assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())

View file

@ -155,7 +155,7 @@ var (
ErrInvalidCredentials = errors.New("invalid credentials")
// ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP
ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP")
isAdminCreated = int32(0)
isAdminCreated atomic.Bool
validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)}
config Config
provider Provider
@ -844,7 +844,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
if err != nil {
return err
}
atomic.StoreInt32(&isAdminCreated, int32(len(admins)))
isAdminCreated.Store(len(admins) > 0)
delayedQuotaUpdater.start()
return startScheduler()
}
@ -1722,7 +1722,7 @@ func UpdateTaskTimestamp(name string) error {
// HasAdmin returns true if the first admin has been created
// and so SFTPGo is ready to be used
func HasAdmin() bool {
return atomic.LoadInt32(&isAdminCreated) > 0
return isAdminCreated.Load()
}
// AddAdmin adds a new SFTPGo admin
@ -1734,7 +1734,7 @@ func AddAdmin(admin *Admin, executor, ipAddress string) error {
admin.Username = config.convertName(admin.Username)
err := provider.addAdmin(admin)
if err == nil {
atomic.StoreInt32(&isAdminCreated, 1)
isAdminCreated.Store(true)
executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, admin)
}
return err

View file

@ -28,11 +28,11 @@ import (
var (
scheduler *cron.Cron
lastUserCacheUpdate int64
lastUserCacheUpdate atomic.Int64
// used for bolt and memory providers, so we avoid iterating all users/rules
// to find recently modified ones
lastUserUpdate int64
lastRuleUpdate int64
lastUserUpdate atomic.Int64
lastRuleUpdate atomic.Int64
)
func stopScheduler() {
@ -62,7 +62,7 @@ func startScheduler() error {
}
func addScheduledCacheUpdates() error {
lastUserCacheUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
_, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
if err != nil {
return fmt.Errorf("unable to schedule cache updates: %w", err)
@ -79,9 +79,9 @@ func checkDataprovider() {
}
func checkCacheUpdates() {
providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate)
users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate.Load())
if err != nil {
providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
return
@ -102,22 +102,22 @@ func checkCacheUpdates() {
cachedPasswords.Remove(user.Username)
}
lastUserCacheUpdate = checkTime
providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate))
lastUserCacheUpdate.Store(checkTime)
providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
}
func setLastUserUpdate() {
atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
}
func getLastUserUpdate() int64 {
return atomic.LoadInt64(&lastUserUpdate)
return lastUserUpdate.Load()
}
func setLastRuleUpdate() {
atomic.StoreInt64(&lastRuleUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
}
func getLastRuleUpdate() int64 {
return atomic.LoadInt64(&lastRuleUpdate)
return lastRuleUpdate.Load()
}

View file

@ -17,7 +17,6 @@ package ftpd
import (
"errors"
"io"
"sync/atomic"
"github.com/eikenb/pipeat"
@ -61,7 +60,7 @@ func (t *transfer) Read(p []byte) (n int, err error) {
t.Connection.UpdateLastActivity()
n, err = t.reader.Read(p)
atomic.AddInt64(&t.BytesSent, int64(n))
t.BytesSent.Add(int64(n))
if err == nil {
err = t.CheckRead()
@ -79,7 +78,7 @@ func (t *transfer) Write(p []byte) (n int, err error) {
t.Connection.UpdateLastActivity()
n, err = t.writer.Write(p)
atomic.AddInt64(&t.BytesReceived, int64(n))
t.BytesReceived.Add(int64(n))
if err == nil {
err = t.CheckWrite()

View file

@ -16,7 +16,6 @@ package httpd
import (
"io"
"sync/atomic"
"github.com/eikenb/pipeat"
@ -52,7 +51,7 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter,
// Read reads the contents to downloads.
func (f *httpdFile) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
if f.AbortTransfer.Load() {
err := f.GetAbortError()
f.TransferError(err)
return 0, err
@ -61,7 +60,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
f.Connection.UpdateLastActivity()
n, err = f.reader.Read(p)
atomic.AddInt64(&f.BytesSent, int64(n))
f.BytesSent.Add(int64(n))
if err == nil {
err = f.CheckRead()
@ -76,7 +75,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) {
// Write writes the contents to upload
func (f *httpdFile) Write(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
if f.AbortTransfer.Load() {
err := f.GetAbortError()
f.TransferError(err)
return 0, err
@ -85,7 +84,7 @@ func (f *httpdFile) Write(p []byte) (n int, err error) {
f.Connection.UpdateLastActivity()
n, err = f.writer.Write(p)
atomic.AddInt64(&f.BytesReceived, int64(n))
f.BytesReceived.Add(int64(n))
if err == nil {
err = f.CheckWrite()

View file

@ -238,24 +238,24 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request
func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader {
t := &throttledReader{
bytesRead: 0,
id: conn.GetTransferID(),
limit: limit,
r: r,
abortTransfer: 0,
start: time.Now(),
conn: conn,
id: conn.GetTransferID(),
limit: limit,
r: r,
start: time.Now(),
conn: conn,
}
t.bytesRead.Store(0)
t.abortTransfer.Store(false)
conn.AddTransfer(t)
return t
}
type throttledReader struct {
bytesRead int64
bytesRead atomic.Int64
id int64
limit int64
r io.ReadCloser
abortTransfer int32
abortTransfer atomic.Bool
start time.Time
conn *Connection
mu sync.Mutex
@ -271,7 +271,7 @@ func (t *throttledReader) GetType() int {
}
func (t *throttledReader) GetSize() int64 {
return atomic.LoadInt64(&t.bytesRead)
return t.bytesRead.Load()
}
func (t *throttledReader) GetDownloadedSize() int64 {
@ -279,7 +279,7 @@ func (t *throttledReader) GetDownloadedSize() int64 {
}
func (t *throttledReader) GetUploadedSize() int64 {
return atomic.LoadInt64(&t.bytesRead)
return t.bytesRead.Load()
}
func (t *throttledReader) GetVirtualPath() string {
@ -304,7 +304,7 @@ func (t *throttledReader) SignalClose(err error) {
t.mu.Lock()
t.errAbort = err
t.mu.Unlock()
atomic.StoreInt32(&(t.abortTransfer), 1)
t.abortTransfer.Store(true)
}
func (t *throttledReader) GetTruncatedSize() int64 {
@ -328,15 +328,15 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti
}
func (t *throttledReader) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&t.abortTransfer) == 1 {
if t.abortTransfer.Load() {
return 0, t.GetAbortError()
}
t.conn.UpdateLastActivity()
n, err = t.r.Read(p)
if t.limit > 0 {
atomic.AddInt64(&t.bytesRead, int64(n))
trasferredBytes := atomic.LoadInt64(&t.bytesRead)
t.bytesRead.Add(int64(n))
trasferredBytes := t.bytesRead.Load()
elapsed := time.Since(t.start).Nanoseconds() / 1000000
wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit
if wantedElapsed > elapsed {

View file

@ -93,7 +93,7 @@ func (c *Config) newKMSPluginSecretProvider(base kms.BaseSecret, url, masterKey
// Manager handles enabled plugins
type Manager struct {
closed int32
closed atomic.Bool
done chan bool
// List of configured plugins
Configs []Config `json:"plugins" mapstructure:"plugins"`
@ -124,10 +124,10 @@ func Initialize(configs []Config, logLevel string) error {
Handler = Manager{
Configs: configs,
done: make(chan bool),
closed: 0,
authScopes: -1,
concurrencyGuard: make(chan struct{}, 250),
}
Handler.closed.Store(false)
setLogLevel(logLevel)
if len(configs) == 0 {
return nil
@ -604,7 +604,7 @@ func (m *Manager) checkCrashedPlugins() {
}
func (m *Manager) restartNotifierPlugin(config Config, idx int) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed notifier plugin %#v, idx: %v", config.Cmd, idx)
@ -622,7 +622,7 @@ func (m *Manager) restartNotifierPlugin(config Config, idx int) {
}
func (m *Manager) restartKMSPlugin(config Config, idx int) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed kms plugin %#v, idx: %v", config.Cmd, idx)
@ -638,7 +638,7 @@ func (m *Manager) restartKMSPlugin(config Config, idx int) {
}
func (m *Manager) restartAuthPlugin(config Config, idx int) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed auth plugin %#v, idx: %v", config.Cmd, idx)
@ -654,7 +654,7 @@ func (m *Manager) restartAuthPlugin(config Config, idx int) {
}
func (m *Manager) restartSearcherPlugin(config Config) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed searcher plugin %#v", config.Cmd)
@ -670,7 +670,7 @@ func (m *Manager) restartSearcherPlugin(config Config) {
}
func (m *Manager) restartMetadaterPlugin(config Config) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed metadater plugin %#v", config.Cmd)
@ -686,7 +686,7 @@ func (m *Manager) restartMetadaterPlugin(config Config) {
}
func (m *Manager) restartIPFilterPlugin(config Config) {
if atomic.LoadInt32(&m.closed) == 1 {
if m.closed.Load() {
return
}
logger.Info(logSender, "", "try to restart crashed IP filter plugin %#v", config.Cmd)
@ -712,7 +712,7 @@ func (m *Manager) removeTask() {
// Cleanup releases all the active plugins
func (m *Manager) Cleanup() {
logger.Debug(logSender, "", "cleanup")
atomic.StoreInt32(&m.closed, 1)
m.closed.Store(true)
close(m.done)
m.notifLock.Lock()
for _, n := range m.notifiers {

View file

@ -1785,7 +1785,7 @@ func TestUploadError(t *testing.T) {
if assert.Error(t, transfer.ErrTransfer) {
assert.EqualError(t, transfer.ErrTransfer, errFake.Error())
}
assert.Equal(t, int64(0), transfer.BytesReceived)
assert.Equal(t, int64(0), transfer.BytesReceived.Load())
assert.NoFileExists(t, testfile)
assert.NoFileExists(t, fileTempName)

View file

@ -1114,12 +1114,12 @@ func TestConcurrency(t *testing.T) {
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
closedConns := int32(0)
var closedConns atomic.Int32
for i := 0; i < numLogins; i++ {
wg.Add(1)
go func(counter int) {
defer wg.Done()
defer atomic.AddInt32(&closedConns, 1)
defer closedConns.Add(1)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
@ -1139,7 +1139,7 @@ func TestConcurrency(t *testing.T) {
maxConns := 0
maxSessions := 0
for {
servedReqs := atomic.LoadInt32(&closedConns)
servedReqs := closedConns.Load()
if servedReqs > 0 {
stats := common.Connections.GetStats()
if len(stats) > maxConns {

View file

@ -17,7 +17,6 @@ package sftpd
import (
"fmt"
"io"
"sync/atomic"
"github.com/eikenb/pipeat"
@ -107,7 +106,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) {
t.Connection.UpdateLastActivity()
n, err = t.readerAt.ReadAt(p, off)
atomic.AddInt64(&t.BytesSent, int64(n))
t.BytesSent.Add(int64(n))
if err == nil {
err = t.CheckRead()
@ -133,7 +132,7 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
}
n, err = t.writerAt.WriteAt(p, off)
atomic.AddInt64(&t.BytesReceived, int64(n))
t.BytesReceived.Add(int64(n))
if err == nil {
err = t.CheckWrite()
@ -213,13 +212,13 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
if nw > 0 {
written += int64(nw)
if isDownload {
atomic.StoreInt64(&t.BytesSent, written)
t.BytesSent.Store(written)
if errCheck := t.CheckRead(); errCheck != nil {
err = errCheck
break
}
} else {
atomic.StoreInt64(&t.BytesReceived, written)
t.BytesReceived.Store(written)
if errCheck := t.CheckWrite(); errCheck != nil {
err = errCheck
break
@ -245,7 +244,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
}
t.ErrTransfer = err
if written > 0 || err != nil {
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.GetType(),
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.GetType(),
t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
}
return written, err

View file

@ -32,14 +32,14 @@ func (l *listener) Accept() (net.Conn, error) {
return nil, err
}
tc := &Conn{
Conn: c,
ReadTimeout: l.ReadTimeout,
WriteTimeout: l.WriteTimeout,
ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second),
WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
BytesReadFromDeadline: 0,
BytesWrittenFromDeadline: 0,
Conn: c,
ReadTimeout: l.ReadTimeout,
WriteTimeout: l.WriteTimeout,
ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second),
WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second),
}
tc.BytesReadFromDeadline.Store(0)
tc.BytesWrittenFromDeadline.Store(0)
return tc, nil
}
@ -51,13 +51,13 @@ type Conn struct {
WriteTimeout time.Duration
ReadThreshold int32
WriteThreshold int32
BytesReadFromDeadline int32
BytesWrittenFromDeadline int32
BytesReadFromDeadline atomic.Int32
BytesWrittenFromDeadline atomic.Int32
}
func (c *Conn) Read(b []byte) (n int, err error) {
if atomic.LoadInt32(&c.BytesReadFromDeadline) > c.ReadThreshold {
atomic.StoreInt32(&c.BytesReadFromDeadline, 0)
if c.BytesReadFromDeadline.Load() > c.ReadThreshold {
c.BytesReadFromDeadline.Store(0)
// we set both read and write deadlines here otherwise after the request
// is read writing the response fails with an i/o timeout error
err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout))
@ -66,13 +66,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
}
n, err = c.Conn.Read(b)
atomic.AddInt32(&c.BytesReadFromDeadline, int32(n))
c.BytesReadFromDeadline.Add(int32(n))
return
}
func (c *Conn) Write(b []byte) (n int, err error) {
if atomic.LoadInt32(&c.BytesWrittenFromDeadline) > c.WriteThreshold {
atomic.StoreInt32(&c.BytesWrittenFromDeadline, 0)
if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold {
c.BytesWrittenFromDeadline.Store(0)
// we extend the read deadline too, not sure it's necessary,
// but it doesn't hurt
err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout))
@ -81,7 +81,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
}
}
n, err = c.Conn.Write(b)
atomic.AddInt32(&c.BytesWrittenFromDeadline, int32(n))
c.BytesWrittenFromDeadline.Add(int32(n))
return
}

View file

@ -902,7 +902,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
finished := false
var wg sync.WaitGroup
var errOnce sync.Once
var hasError int32
var hasError atomic.Bool
var poolError error
poolCtx, poolCancel := context.WithCancel(ctx)
@ -919,7 +919,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
offset = end
guard <- struct{}{}
if atomic.LoadInt32(&hasError) == 1 {
if hasError.Load() {
fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part)
break
}
@ -941,7 +941,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a
if err != nil {
errOnce.Do(func() {
fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
atomic.StoreInt32(&hasError, 1)
hasError.Store(true)
poolError = fmt.Errorf("multipart download error: %w", err)
poolCancel()
})
@ -971,7 +971,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
var blocks []string
var wg sync.WaitGroup
var errOnce sync.Once
var hasError int32
var hasError atomic.Bool
var poolError error
poolCtx, poolCancel := context.WithCancel(ctx)
@ -999,7 +999,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
blocks = append(blocks, blockID)
guard <- struct{}{}
if atomic.LoadInt32(&hasError) == 1 {
if hasError.Load() {
fsLog(fs, logger.LevelError, "pool error, upload for part %v not started", part)
pool.releaseBuffer(buf)
break
@ -1023,7 +1023,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
if err != nil {
errOnce.Do(func() {
fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err)
atomic.StoreInt32(&hasError, 1)
hasError.Store(true)
poolError = fmt.Errorf("multipart upload error: %w", err)
poolCancel()
})

View file

@ -835,7 +835,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
var completedParts []types.CompletedPart
var partMutex sync.Mutex
var wg sync.WaitGroup
var hasError int32
var hasError atomic.Bool
var errOnce sync.Once
var copyError error
var partNumber int32
@ -854,7 +854,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
offset = end
guard <- struct{}{}
if atomic.LoadInt32(&hasError) == 1 {
if hasError.Load() {
fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber)
break
}
@ -880,7 +880,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
if err != nil {
errOnce.Do(func() {
fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err)
atomic.StoreInt32(&hasError, 1)
hasError.Store(true)
copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
opCancel()

View file

@ -42,7 +42,7 @@ type webDavFile struct {
info os.FileInfo
startOffset int64
isFinished bool
readTryed int32
readTryed atomic.Bool
}
func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *webDavFile {
@ -56,15 +56,16 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter
} else if pipeReader != nil {
reader = pipeReader
}
return &webDavFile{
f := &webDavFile{
BaseTransfer: baseTransfer,
writer: writer,
reader: reader,
isFinished: false,
startOffset: 0,
info: nil,
readTryed: 0,
}
f.readTryed.Store(false)
return f
}
type webDavFileInfo struct {
@ -124,7 +125,7 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {
f.Unlock()
if f.GetType() == common.TransferUpload && errUpload == nil {
info := &webDavFileInfo{
FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Unix(0, 0), false),
FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Unix(0, 0), false),
Fs: f.Fs,
virtualPath: f.GetVirtualPath(),
fsPath: f.GetFsPath(),
@ -149,10 +150,10 @@ func (f *webDavFile) Stat() (os.FileInfo, error) {
// Read reads the contents to downloads.
func (f *webDavFile) Read(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
if f.AbortTransfer.Load() {
return 0, errTransferAborted
}
if atomic.LoadInt32(&f.readTryed) == 0 {
if !f.readTryed.Load() {
if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) {
return 0, f.Connection.GetPermissionDeniedError()
}
@ -171,7 +172,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
f.Connection.Log(logger.LevelDebug, "download for file %#v denied by pre action: %v", f.GetVirtualPath(), err)
return 0, f.Connection.GetPermissionDeniedError()
}
atomic.StoreInt32(&f.readTryed, 1)
f.readTryed.Store(true)
}
f.Connection.UpdateLastActivity()
@ -198,7 +199,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
}
n, err = f.reader.Read(p)
atomic.AddInt64(&f.BytesSent, int64(n))
f.BytesSent.Add(int64(n))
if err == nil {
err = f.CheckRead()
}
@ -212,14 +213,14 @@ func (f *webDavFile) Read(p []byte) (n int, err error) {
// Write writes the uploaded contents.
func (f *webDavFile) Write(p []byte) (n int, err error) {
if atomic.LoadInt32(&f.AbortTransfer) == 1 {
if f.AbortTransfer.Load() {
return 0, errTransferAborted
}
f.Connection.UpdateLastActivity()
n, err = f.writer.Write(p)
atomic.AddInt64(&f.BytesReceived, int64(n))
f.BytesReceived.Add(int64(n))
if err == nil {
err = f.CheckWrite()
@ -252,7 +253,7 @@ func (f *webDavFile) updateTransferQuotaOnSeek() {
if transferQuota.HasSizeLimits() {
go func(ulSize, dlSize int64, user dataprovider.User) {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
}(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User)
}
}
@ -270,7 +271,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
return ret, err
}
if f.GetType() == common.TransferDownload {
readOffset := f.startOffset + atomic.LoadInt64(&f.BytesSent)
readOffset := f.startOffset + f.BytesSent.Load()
if offset == 0 && readOffset == 0 {
if whence == io.SeekStart {
return 0, nil
@ -288,8 +289,8 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
f.reader = nil
}
startByte := int64(0)
atomic.StoreInt64(&f.BytesReceived, 0)
atomic.StoreInt64(&f.BytesSent, 0)
f.BytesReceived.Store(0)
f.BytesSent.Store(0)
f.updateTransferQuotaOnSeek()
switch whence {
@ -369,7 +370,7 @@ func (f *webDavFile) setFinished() error {
func (f *webDavFile) isTransfer() bool {
if f.GetType() == common.TransferDownload {
return atomic.LoadInt32(&f.readTryed) > 0
return f.readTryed.Load()
}
return true
}