From be9230e85b5d682ef33bc28c3139237e199a123e Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Tue, 16 Feb 2021 19:11:36 +0100 Subject: [PATCH] micro optimizations spotted using the go-critic linter --- cmd/startsubsys.go | 2 +- common/actions.go | 14 ++++---- common/actions_test.go | 4 +-- common/common.go | 12 +++---- common/connection.go | 66 +++++++++++++++++----------------- common/connection_test.go | 16 ++++----- common/transfer.go | 6 ++-- dataprovider/bolt.go | 7 ++-- dataprovider/dataprovider.go | 68 ++++++++++++++++++------------------ dataprovider/memory.go | 4 +-- dataprovider/mysql.go | 8 ++--- dataprovider/pgsql.go | 8 ++--- dataprovider/sqlcommon.go | 41 +++------------------- dataprovider/sqlite.go | 8 ++--- ftpd/handler.go | 6 ++-- ftpd/server.go | 2 +- httpd/api_quota.go | 8 ++--- httpd/web.go | 2 +- logger/logger.go | 2 +- sftpd/handler.go | 8 ++--- sftpd/internal_test.go | 4 +-- sftpd/scp.go | 6 ++-- sftpd/server.go | 10 +++--- sftpd/sftpd_test.go | 2 +- sftpd/ssh_cmd.go | 6 ++-- sftpd/subsystem.go | 4 +-- vfs/azblobfs.go | 17 +++++---- webdavd/handler.go | 6 ++-- webdavd/server.go | 2 +- 29 files changed, 160 insertions(+), 189 deletions(-) diff --git a/cmd/startsubsys.go b/cmd/startsubsys.go index b65b7cbf..ce91b605 100644 --- a/cmd/startsubsys.go +++ b/cmd/startsubsys.go @@ -122,7 +122,7 @@ Command-line flags should be specified in the Subsystem declaration. os.Exit(1) } } - err = sftpd.ServeSubSystemConnection(user, connectionID, os.Stdin, os.Stdout) + err = sftpd.ServeSubSystemConnection(&user, connectionID, os.Stdin, os.Stdout) if err != nil && err != io.EOF { logger.Warn(logSender, connectionID, "serving subsystem finished with error: %v", err) os.Exit(1) diff --git a/common/actions.go b/common/actions.go index 9f59231c..c9e5bd1e 100644 --- a/common/actions.go +++ b/common/actions.go @@ -52,7 +52,7 @@ func SSHCommandActionNotification(user *dataprovider.User, filePath, target, ssh // ActionHandler handles a notification for a Protocol Action. type ActionHandler interface { - Handle(notification ActionNotification) error + Handle(notification *ActionNotification) error } // ActionNotification defines a notification for a Protocol Action. @@ -75,7 +75,7 @@ func newActionNotification( operation, filePath, target, sshCmd, protocol string, fileSize int64, err error, -) ActionNotification { +) *ActionNotification { var bucket, endpoint string status := 1 @@ -99,7 +99,7 @@ func newActionNotification( status = 0 } - return ActionNotification{ + return &ActionNotification{ Action: operation, Username: user.Username, Path: filePath, @@ -116,7 +116,7 @@ func newActionNotification( type defaultActionHandler struct{} -func (h *defaultActionHandler) Handle(notification ActionNotification) error { +func (h *defaultActionHandler) Handle(notification *ActionNotification) error { if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) { return errUnconfiguredAction } @@ -134,7 +134,7 @@ func (h *defaultActionHandler) Handle(notification ActionNotification) error { return h.handleCommand(notification) } -func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error { +func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) error { u, err := url.Parse(Config.Actions.Hook) if err != nil { logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err) @@ -165,7 +165,7 @@ func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error return err } -func (h *defaultActionHandler) handleCommand(notification ActionNotification) error { +func (h *defaultActionHandler) handleCommand(notification *ActionNotification) error { if !filepath.IsAbs(Config.Actions.Hook) { err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook) logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err) @@ -188,7 +188,7 @@ func (h *defaultActionHandler) handleCommand(notification ActionNotification) er return err } -func notificationAsEnvVars(notification ActionNotification) []string { +func notificationAsEnvVars(notification *ActionNotification) []string { return []string{ fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action), fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username), diff --git a/common/actions_test.go b/common/actions_test.go index 51eb3b56..07a00a0b 100644 --- a/common/actions_test.go +++ b/common/actions_test.go @@ -201,7 +201,7 @@ type actionHandlerStub struct { called bool } -func (h *actionHandlerStub) Handle(notification ActionNotification) error { +func (h *actionHandlerStub) Handle(notification *ActionNotification) error { h.called = true return nil @@ -215,7 +215,7 @@ func TestInitializeActionHandler(t *testing.T) { InitializeActionHandler(&defaultActionHandler{}) }) - err := actionHandler.Handle(ActionNotification{}) + err := actionHandler.Handle(&ActionNotification{}) assert.NoError(t, err) assert.True(t, handler.called) diff --git a/common/common.go b/common/common.go index 656bfc14..08d2b79e 100644 --- a/common/common.go +++ b/common/common.go @@ -630,13 +630,13 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool { } // GetStats returns stats for active connections -func (conns *ActiveConnections) GetStats() []ConnectionStatus { +func (conns *ActiveConnections) GetStats() []*ConnectionStatus { conns.RLock() defer conns.RUnlock() - stats := make([]ConnectionStatus, 0, len(conns.connections)) + stats := make([]*ConnectionStatus, 0, len(conns.connections)) for _, c := range conns.connections { - stat := ConnectionStatus{ + stat := &ConnectionStatus{ Username: c.GetUsername(), ConnectionID: c.GetID(), ClientVersion: c.GetClientVersion(), @@ -675,14 +675,14 @@ type ConnectionStatus struct { } // GetConnectionDuration returns the connection duration as string -func (c ConnectionStatus) GetConnectionDuration() string { +func (c *ConnectionStatus) GetConnectionDuration() string { elapsed := time.Since(utils.GetTimeFromMsecSinceEpoch(c.ConnectionTime)) return utils.GetDurationAsString(elapsed) } // GetConnectionInfo returns connection info. // Protocol,Client Version and RemoteAddress are returned. -func (c ConnectionStatus) GetConnectionInfo() string { +func (c *ConnectionStatus) GetConnectionInfo() string { var result strings.Builder result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress)) @@ -702,7 +702,7 @@ func (c ConnectionStatus) GetConnectionInfo() string { } // GetTransfersAsString returns the active transfers as string -func (c ConnectionStatus) GetTransfersAsString() string { +func (c *ConnectionStatus) GetTransfersAsString() string { result := "" for _, t := range c.Transfers { if result != "" { diff --git a/common/connection.go b/common/connection.go index 62700d38..87063589 100644 --- a/common/connection.go +++ b/common/connection.go @@ -37,10 +37,10 @@ type BaseConnection struct { } // NewBaseConnection returns a new BaseConnection -func NewBaseConnection(ID, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection { - connID := ID +func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection { + connID := id if utils.IsStringInSlice(protocol, supportedProtocols) { - connID = fmt.Sprintf("%v_%v", protocol, ID) + connID = fmt.Sprintf("%v_%v", protocol, id) } return &BaseConnection{ ID: connID, @@ -272,12 +272,12 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo if info.Mode()&os.ModeSymlink == 0 { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck } } if actionErr != nil { @@ -577,12 +577,12 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er sizeDiff := initialSize - size vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck } } return err @@ -835,64 +835,64 @@ func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetP return true } -func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder vfs.VirtualFolder, initialSize, +func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { if sourceFolder.MappedPath == dstFolder.MappedPath { // both files are inside the same virtual folder if initialSize != -1 { - dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck if dstFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, -numFiles, -initialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck } } return } // files are inside different virtual folders - dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck if sourceFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck } if initialSize == -1 { - dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck if dstFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck } } else { // we cannot have a directory here, initialSize != -1 only for files - dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck if dstFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck } } } -func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { +func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { // move between a virtual folder and the user home dir - dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck if sourceFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck } if initialSize == -1 { - dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck } else { // we cannot have a directory here, initialSize != -1 only for files - dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck } } -func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { +func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { // move between the user home dir and a virtual folder - dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck if initialSize == -1 { - dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck if dstFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck } } else { // we cannot have a directory here, initialSize != -1 only for files - dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck if dstFolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck } } } @@ -909,7 +909,7 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget // both files are contained inside the user home dir if initialSize != -1 { // we cannot have a directory here - dataprovider.UpdateUserQuota(c.User, -1, -initialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck } return nil } @@ -932,13 +932,13 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget return err } if errSrc == nil && errDst == nil { - c.updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder, initialSize, filesSize, numFiles) + c.updateQuotaMoveBetweenVFolders(&sourceFolder, &dstFolder, initialSize, filesSize, numFiles) } if errSrc == nil && errDst != nil { - c.updateQuotaMoveFromVFolder(sourceFolder, initialSize, filesSize, numFiles) + c.updateQuotaMoveFromVFolder(&sourceFolder, initialSize, filesSize, numFiles) } if errSrc != nil && errDst == nil { - c.updateQuotaMoveToVFolder(dstFolder, initialSize, filesSize, numFiles) + c.updateQuotaMoveToVFolder(&dstFolder, initialSize, filesSize, numFiles) } return nil } diff --git a/common/connection_test.go b/common/connection_test.go index 6dcb9177..8b8c44b6 100644 --- a/common/connection_test.go +++ b/common/connection_test.go @@ -1054,7 +1054,7 @@ func TestHasSpace(t *testing.T) { folder, err := dataprovider.GetFolderByName(folderName) assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(folder, 10, 1048576, true) + err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 1048576, true) assert.NoError(t, err) quotaResult = c.HasSpace(true, false, "/vdir/file1") assert.False(t, quotaResult.HasSpace) @@ -1105,14 +1105,14 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) { assert.NoError(t, err) folder2, err := dataprovider.GetFolderByName(folderName2) assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(folder1, 1, 100, true) + err = dataprovider.UpdateVirtualFolderQuota(&folder1, 1, 100, true) assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(folder2, 2, 150, true) + err = dataprovider.UpdateVirtualFolderQuota(&folder2, 2, 150, true) assert.NoError(t, err) fs, err := user.GetFilesystem("id") assert.NoError(t, err) c := NewBaseConnection("", ProtocolSFTP, user, fs) - c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[0], user.VirtualFolders[1], -1, 100, 1) + c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[0], &user.VirtualFolders[1], -1, 100, 1) folder1, err = dataprovider.GetFolderByName(folderName1) assert.NoError(t, err) assert.Equal(t, 0, folder1.UsedQuotaFiles) @@ -1122,7 +1122,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) { assert.Equal(t, 3, folder2.UsedQuotaFiles) assert.Equal(t, int64(250), folder2.UsedQuotaSize) - c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[1], user.VirtualFolders[0], 10, 100, 1) + c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[1], &user.VirtualFolders[0], 10, 100, 1) folder1, err = dataprovider.GetFolderByName(folderName1) assert.NoError(t, err) assert.Equal(t, 0, folder1.UsedQuotaFiles) @@ -1132,9 +1132,9 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) { assert.Equal(t, 2, folder2.UsedQuotaFiles) assert.Equal(t, int64(150), folder2.UsedQuotaSize) - err = dataprovider.UpdateUserQuota(user, 1, 100, true) + err = dataprovider.UpdateUserQuota(&user, 1, 100, true) assert.NoError(t, err) - c.updateQuotaMoveFromVFolder(user.VirtualFolders[1], -1, 50, 1) + c.updateQuotaMoveFromVFolder(&user.VirtualFolders[1], -1, 50, 1) folder2, err = dataprovider.GetFolderByName(folderName2) assert.NoError(t, err) assert.Equal(t, 1, folder2.UsedQuotaFiles) @@ -1144,7 +1144,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) { assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, int64(100), user.UsedQuotaSize) - c.updateQuotaMoveToVFolder(user.VirtualFolders[1], -1, 100, 1) + c.updateQuotaMoveToVFolder(&user.VirtualFolders[1], -1, 100, 1) folder2, err = dataprovider.GetFolderByName(folderName2) assert.NoError(t, err) assert.Equal(t, 2, folder2.UsedQuotaFiles) diff --git a/common/transfer.go b/common/transfer.go index 093e2cb7..19f50906 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -266,13 +266,13 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool { if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) { vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck sizeDiff, false) if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck } return true } diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index c2cb7f81..77a7ea9e 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -22,8 +22,7 @@ const ( ) var ( - usersBucket = []byte("users") - //usersIDIdxBucket = []byte("users_id_idx") + usersBucket = []byte("users") foldersBucket = []byte("folders") adminsBucket = []byte("admins") dbVersionBucket = []byte("db_version") @@ -113,7 +112,7 @@ func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol stri providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, err } - return checkUserAndPass(user, password, ip, protocol) + return checkUserAndPass(&user, password, ip, protocol) } func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { @@ -136,7 +135,7 @@ func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (Us providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, "", err } - return checkUserAndPubKey(user, pubKey) + return checkUserAndPubKey(&user, pubKey) } func (p *BoltProvider) updateLastLogin(username string) error { diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index b7cfaf50..bf6b161c 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -625,14 +625,14 @@ func CheckUserAndPass(username, password, ip, protocol string) (User, error) { if err != nil { return user, err } - return checkUserAndPass(user, password, ip, protocol) + return checkUserAndPass(&user, password, ip, protocol) } if config.PreLoginHook != "" { user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol) if err != nil { return user, err } - return checkUserAndPass(user, password, ip, protocol) + return checkUserAndPass(&user, password, ip, protocol) } return provider.validateUserAndPass(username, password, ip, protocol) } @@ -644,14 +644,14 @@ func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string) (Us if err != nil { return user, "", err } - return checkUserAndPubKey(user, pubKey) + return checkUserAndPubKey(&user, pubKey) } if config.PreLoginHook != "" { user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol) if err != nil { return user, "", err } - return checkUserAndPubKey(user, pubKey) + return checkUserAndPubKey(&user, pubKey) } return provider.validateUserAndPubKey(username, pubKey) } @@ -671,11 +671,11 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard if err != nil { return user, err } - return doKeyboardInteractiveAuth(user, authHook, client, ip, protocol) + return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol) } // UpdateLastLogin updates the last login fields for the given SFTP user -func UpdateLastLogin(user User) error { +func UpdateLastLogin(user *User) error { lastLogin := utils.GetTimeFromMsecSinceEpoch(user.LastLogin) diff := -time.Until(lastLogin) if diff < 0 || diff > lastLoginMinDelay { @@ -690,7 +690,7 @@ func UpdateLastLogin(user User) error { // UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. -func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error { +func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error { if config.TrackQuota == 0 { return &MethodDisabledError{err: trackQuotaDisabledError} } else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() { @@ -704,7 +704,7 @@ func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error { // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. -func UpdateVirtualFolderQuota(vfolder vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error { +func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error { if config.TrackQuota == 0 { return &MethodDisabledError{err: trackQuotaDisabledError} } @@ -1482,53 +1482,53 @@ func isPasswordOK(user *User, password string) (bool, error) { return match, err } -func checkUserAndPass(user User, password, ip, protocol string) (User, error) { - err := checkLoginConditions(&user) +func checkUserAndPass(user *User, password, ip, protocol string) (User, error) { + err := checkLoginConditions(user) if err != nil { - return user, err + return *user, err } if user.Password == "" { - return user, errors.New("Credentials cannot be null or empty") + return *user, errors.New("Credentials cannot be null or empty") } hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol) if err != nil { providerLog(logger.LevelDebug, "error executing check password hook: %v", err) - return user, errors.New("Unable to check credentials") + return *user, errors.New("Unable to check credentials") } switch hookResponse.Status { case -1: // no hook configured case 1: providerLog(logger.LevelDebug, "password accepted by check password hook") - return user, nil + return *user, nil case 2: providerLog(logger.LevelDebug, "partial success from check password hook") password = hookResponse.ToVerify default: providerLog(logger.LevelDebug, "password rejected by check password hook, status: %v", hookResponse.Status) - return user, ErrInvalidCredentials + return *user, ErrInvalidCredentials } - match, err := isPasswordOK(&user, password) + match, err := isPasswordOK(user, password) if !match { err = ErrInvalidCredentials } - return user, err + return *user, err } -func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) { - err := checkLoginConditions(&user) +func checkUserAndPubKey(user *User, pubKey []byte) (User, string, error) { + err := checkLoginConditions(user) if err != nil { - return user, "", err + return *user, "", err } if len(user.PublicKeys) == 0 { - return user, "", ErrInvalidCredentials + return *user, "", ErrInvalidCredentials } for i, k := range user.PublicKeys { storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) if err != nil { providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err) - return user, "", err + return *user, "", err } if bytes.Equal(storedPubKey.Marshal(), pubKey) { certInfo := "" @@ -1537,10 +1537,10 @@ func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) { certInfo = fmt.Sprintf(" %v ID: %v Serial: %v CA: %v", cert.Type(), cert.KeyId, cert.Serial, ssh.FingerprintSHA256(cert.SignatureKey)) } - return user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil + return *user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil } } - return user, "", ErrInvalidCredentials + return *user, "", ErrInvalidCredentials } func compareUnixPasswordAndHash(user *User, password string) (bool, error) { @@ -1712,7 +1712,7 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key return response, err } -func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { +func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { authResult := 0 var url *url.URL url, err := url.Parse(authHook) @@ -1754,7 +1754,7 @@ func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.K } func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse, - user User, ip, protocol string) ([]string, error) { + user *User, ip, protocol string) ([]string, error) { questions := response.Questions answers, err := client(user.Username, response.Instruction, questions, response.Echos) if err != nil { @@ -1779,7 +1779,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp } func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse, - user User, stdin io.WriteCloser, ip, protocol string) error { + user *User, stdin io.WriteCloser, ip, protocol string) error { answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol) if err != nil { return err @@ -1798,7 +1798,7 @@ func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, return nil } -func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { +func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { authResult := 0 ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() @@ -1856,7 +1856,7 @@ func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.Ke return authResult, err } -func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) { +func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) { var authResult int var err error if strings.HasPrefix(authHook, "http") { @@ -1865,16 +1865,16 @@ func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardIn authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol) } if err != nil { - return user, err + return *user, err } if authResult != 1 { - return user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult) + return *user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult) } - err = checkLoginConditions(&user) + err = checkLoginConditions(user) if err != nil { - return user, err + return *user, err } - return user, nil + return *user, nil } func isCheckPasswordHookDefined(protocol string) bool { diff --git a/dataprovider/memory.go b/dataprovider/memory.go index f2d9bff1..67ec2d3d 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -99,7 +99,7 @@ func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol st providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, err } - return checkUserAndPass(user, password, ip, protocol) + return checkUserAndPass(&user, password, ip, protocol) } func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) { @@ -112,7 +112,7 @@ func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) ( providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, "", err } - return checkUserAndPubKey(user, pubKey) + return checkUserAndPubKey(&user, pubKey) } func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index bb871ee2..f32a274d 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -225,23 +225,23 @@ func (p *MySQLProvider) initializeDatabase() error { return ErrNoInitRequired } sqlUsers := strings.Replace(mysqlUsersTableSQL, "{{users}}", sqlTableUsers, 1) - tx, err := p.dbHandle.Begin() + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + tx, err := p.dbHandle.BeginTx(ctx, nil) if err != nil { return err } _, err = tx.Exec(sqlUsers) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(mysqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 0ea778d8..2352dfde 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -229,23 +229,23 @@ func (p *PGSQLProvider) initializeDatabase() error { return ErrNoInitRequired } sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1) - tx, err := p.dbHandle.Begin() + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + tx, err := p.dbHandle.BeginTx(ctx, nil) if err != nil { return err } _, err = tx.Exec(sqlUsers) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 819c71e4..9617d76b 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -224,7 +224,7 @@ func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHan providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, err } - return checkUserAndPass(user, password, ip, protocol) + return checkUserAndPass(&user, password, ip, protocol) } func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) { @@ -237,7 +237,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sq providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) return user, "", err } - return checkUserAndPubKey(user, pubKey) + return checkUserAndPubKey(&user, pubKey) } func sqlCommonCheckAvailability(dbHandle *sql.DB) error { @@ -313,6 +313,7 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() + tx, err := dbHandle.BeginTx(ctx, nil) if err != nil { return err @@ -321,40 +322,33 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { stmt, err := tx.PrepareContext(ctx, q) if err != nil { providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) - sqlCommonRollbackTransaction(tx) return err } defer stmt.Close() permissions, err := user.GetPermissionsAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } publicKeys, err := user.GetPublicKeysAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } filters, err := user.GetFiltersAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } fsConfig, err := user.GetFsConfigAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo) if err != nil { - sqlCommonRollbackTransaction(tx) return err } err = generateVirtualFoldersMapping(ctx, user, tx) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() @@ -367,6 +361,7 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() + tx, err := dbHandle.BeginTx(ctx, nil) if err != nil { return err @@ -375,40 +370,33 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { stmt, err := tx.PrepareContext(ctx, q) if err != nil { providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) - sqlCommonRollbackTransaction(tx) return err } defer stmt.Close() permissions, err := user.GetPermissionsAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } publicKeys, err := user.GetPublicKeysAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } filters, err := user.GetFiltersAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } fsConfig, err := user.GetFsConfigAsJSON() if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.ID) if err != nil { - sqlCommonRollbackTransaction(tx) return err } err = generateVirtualFoldersMapping(ctx, user, tx) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() @@ -979,13 +967,6 @@ func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int6 return usedFiles, usedSize, err } -func sqlCommonRollbackTransaction(tx *sql.Tx) { - err := tx.Rollback() - if err != nil { - providerLog(logger.LevelWarn, "error rolling back transaction: %v", err) - } -} - func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) { var result schemaVersion ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) @@ -1030,13 +1011,11 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersi } _, err = tx.ExecContext(ctx, q) if err != nil { - sqlCommonRollbackTransaction(tx) return err } } err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() @@ -1130,6 +1109,7 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() + tx, err := dbHandle.BeginTx(ctx, nil) if err != nil { return err @@ -1140,25 +1120,14 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error { } _, err = tx.ExecContext(ctx, q) if err != nil { - sqlCommonRollbackTransaction(tx) return err } } - /*_, err = sqlCommonRestoreCompatVirtualFolders(ctx, users, tx) - if err != nil { - sqlCommonRollbackTransaction(tx) - return err - }*/ err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() - /*if err == nil { - go updateVFoldersQuotaAfterRestore(foldersToScan) - } - return err*/ } //nolint:dupl diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index c5a744b9..8e782ce8 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -262,23 +262,23 @@ func (p *SQLiteProvider) initializeDatabase() error { return ErrNoInitRequired } sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1) - tx, err := p.dbHandle.Begin() + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + tx, err := p.dbHandle.BeginTx(ctx, nil) if err != nil { return err } _, err = tx.Exec(sqlUsers) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(sqliteSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) if err != nil { - sqlCommonRollbackTransaction(tx) return err } return tx.Commit() diff --git a/ftpd/handler.go b/ftpd/handler.go index 926e1683..e6b1e490 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -470,12 +470,12 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file if vfs.IsLocalOrSFTPFs(c.Fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize diff --git a/ftpd/server.go b/ftpd/server.go index bfa93e55..b80ad19c 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -152,7 +152,7 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID()) connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v", user.ID, user.Username, user.HomeDir, ipAddr) - dataprovider.UpdateLastLogin(user) //nolint:errcheck + dataprovider.UpdateLastLogin(&user) //nolint:errcheck return connection, nil } diff --git a/httpd/api_quota.go b/httpd/api_quota.go index 34fb8f93..468d4977 100644 --- a/httpd/api_quota.go +++ b/httpd/api_quota.go @@ -58,7 +58,7 @@ func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) { return } defer common.QuotaScans.RemoveUserQuotaScan(user.Username) - err = dataprovider.UpdateUserQuota(user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset) + err = dataprovider.UpdateUserQuota(&user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) } else { @@ -94,7 +94,7 @@ func updateVFolderQuotaUsage(w http.ResponseWriter, r *http.Request) { return } defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) - err = dataprovider.UpdateVirtualFolderQuota(folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset) + err = dataprovider.UpdateVirtualFolderQuota(&folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) } else { @@ -165,7 +165,7 @@ func doQuotaScan(user dataprovider.User) error { logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err) return err } - err = dataprovider.UpdateUserQuota(user, numFiles, size, true) + err = dataprovider.UpdateUserQuota(&user, numFiles, size, true) logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err) return err } @@ -178,7 +178,7 @@ func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error { logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err) return err } - err = dataprovider.UpdateVirtualFolderQuota(folder, numFiles, size, true) + err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true) logger.Debug(logSender, "", "virtual folder %#v scanned, error: %v", folder.Name, err) return err } diff --git a/httpd/web.go b/httpd/web.go index fa75dd02..13c34cf0 100644 --- a/httpd/web.go +++ b/httpd/web.go @@ -123,7 +123,7 @@ type foldersPage struct { type connectionsPage struct { basePage - Connections []common.ConnectionStatus + Connections []*common.ConnectionStatus } type statusPage struct { diff --git a/logger/logger.go b/logger/logger.go index 23a41851..7a4492b0 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -69,7 +69,7 @@ func (l *LeveledLogger) addKeysAndValues(ev *zerolog.Event, keysAndValues ...int extra := keysAndValues[kvLen-1] keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra) } - for i := 0; i < len(keysAndValues); i = i + 2 { + for i := 0; i < len(keysAndValues); i += 2 { key, val := keysAndValues[i], keysAndValues[i+1] if keyStr, ok := key.(string); ok { ev.Str(keyStr, fmt.Sprintf("%v", val)) diff --git a/sftpd/handler.go b/sftpd/handler.go index ac15a600..2b620fdb 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -412,12 +412,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize @@ -460,7 +460,7 @@ func (c *Connection) getStatVFSFromQuotaResult(name string, quotaResult vfs.Quot bsize := uint64(4096) for bsize > uint64(quotaResult.QuotaSize) { - bsize = bsize / 4 + bsize /= 4 } blocks := uint64(quotaResult.QuotaSize) / bsize bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 9f54ca16..8b5a63da 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -365,7 +365,7 @@ func TestUploadFiles(t *testing.T) { func TestWithInvalidHome(t *testing.T) { u := dataprovider.User{} u.HomeDir = "home_rel_path" //nolint:goconst - _, err := loginUser(u, dataprovider.LoginMethodPassword, "", nil) + _, err := loginUser(&u, dataprovider.LoginMethodPassword, "", nil) assert.Error(t, err, "login a user with an invalid home_dir must fail") u.HomeDir = os.TempDir() @@ -1890,7 +1890,7 @@ func TestRecursiveCopyErrors(t *testing.T) { func TestSFTPSubSystem(t *testing.T) { permissions := make(map[string][]string) permissions["/"] = []string{dataprovider.PermAny} - user := dataprovider.User{ + user := &dataprovider.User{ Permissions: permissions, HomeDir: os.TempDir(), } diff --git a/sftpd/scp.go b/sftpd/scp.go index 28c06243..d0c2b692 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -213,12 +213,12 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead if vfs.IsLocalOrSFTPFs(c.connection.Fs) { vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize diff --git a/sftpd/server.go b/sftpd/server.go index 5b902fe5..87dce5b1 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -408,7 +408,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", user.ID, loginType, user.Username, user.HomeDir, ipAddr) - dataprovider.UpdateLastLogin(user) //nolint:errcheck + dataprovider.UpdateLastLogin(&user) //nolint:errcheck sshConnection := common.NewSSHConnection(connectionID, conn) common.Connections.AddSSHConnection(sshConnection) @@ -557,7 +557,7 @@ func checkRootPath(user *dataprovider.User, connectionID string) error { return nil } -func loginUser(user dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) { +func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) { connectionID := "" if conn != nil { connectionID = hex.EncodeToString(conn.SessionID()) @@ -817,7 +817,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User()) return certPerm, ssh.ErrPartialSuccess } - sshPerm, err = loginUser(user, method, keyID, conn) + sshPerm, err = loginUser(&user, method, keyID, conn) if err == nil && certPerm != nil { // if we have a SSH user cert we need to merge certificate permissions with our ones // we only set Extensions, so CriticalOptions are always the ones from the certificate @@ -845,7 +845,7 @@ func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass } ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil { - sshPerm, err = loginUser(user, method, "", conn) + sshPerm, err = loginUser(&user, method, "", conn) } user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) @@ -864,7 +864,7 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client, ipAddr, common.ProtocolSSH); err == nil { - sshPerm, err = loginUser(user, method, "", conn) + sshPerm, err = loginUser(&user, method, "", conn) } user.Username = conn.User() updateLoginMetrics(&user, ipAddr, method, err) diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 6ba19ea4..5aeeed95 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -6490,7 +6490,7 @@ func TestStatVFSCloudBackend(t *testing.T) { if assert.NoError(t, err) { defer client.Close() - err = dataprovider.UpdateUserQuota(user, 100, 8192, true) + err = dataprovider.UpdateUserQuota(&user, 100, 8192, true) assert.NoError(t, err) stat, err := client.StatVFS("/") assert.NoError(t, err) diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 27f9ee2a..8a5ec9d0 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -248,12 +248,12 @@ func (c *sshCommand) handeSFTPGoRemove() error { func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) { vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck } } diff --git a/sftpd/subsystem.go b/sftpd/subsystem.go index 834ed162..c14e007f 100644 --- a/sftpd/subsystem.go +++ b/sftpd/subsystem.go @@ -35,7 +35,7 @@ func newSubsystemChannel(reader io.Reader, writer io.Writer) *subsystemChannel { } // ServeSubSystemConnection handles a connection as SSH subsystem -func ServeSubSystemConnection(user dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error { +func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error { fs, err := user.GetFilesystem(connectionID) if err != nil { return err @@ -44,7 +44,7 @@ func ServeSubSystemConnection(user dataprovider.User, connectionID string, reade dataprovider.UpdateLastLogin(user) //nolint:errcheck connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, *user, fs), ClientVersion: "", RemoteAddr: &net.IPAddr{}, channel: newSubsystemChannel(reader, writer), diff --git a/vfs/azblobfs.go b/vfs/azblobfs.go index 2aa7a4ad..af18547f 100644 --- a/vfs/azblobfs.go +++ b/vfs/azblobfs.go @@ -253,7 +253,7 @@ func (fs *AzureBlobFs) Create(name string, flag int) (File, *PipeWriter, func(), // if we shutdown Azurite while uploading it hangs, so we use our own wrapper for // the low level functions _, err := azblob.UploadStreamToBlockBlob(ctx, r, blobBlockURL, uploadOptions)*/ - err := fs.handleMultipartUpload(ctx, r, blobBlockURL, headers) + err := fs.handleMultipartUpload(ctx, r, &blobBlockURL, &headers) r.CloseWithError(err) //nolint:errcheck p.Done(err) fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, r.GetReadedBytes(), err) @@ -438,7 +438,8 @@ func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) { result = append(result, NewFileInfo(name, true, 0, time.Now(), false)) prefixes[strings.TrimSuffix(name, "/")] = true } - for _, blobInfo := range listBlob.Segment.BlobItems { + for idx := range listBlob.Segment.BlobItems { + blobInfo := &listBlob.Segment.BlobItems[idx] name := strings.TrimPrefix(blobInfo.Name, prefix) size := int64(0) if blobInfo.Properties.ContentLength != nil { @@ -556,7 +557,8 @@ func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) { return numFiles, size, err } marker = listBlob.NextMarker - for _, blobInfo := range listBlob.Segment.BlobItems { + for idx := range listBlob.Segment.BlobItems { + blobInfo := &listBlob.Segment.BlobItems[idx] isDir := false if blobInfo.Properties.ContentType != nil { isDir = (*blobInfo.Properties.ContentType == dirMimeType) @@ -637,7 +639,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error { return err } marker = listBlob.NextMarker - for _, blobInfo := range listBlob.Segment.BlobItems { + for idx := range listBlob.Segment.BlobItems { + blobInfo := &listBlob.Segment.BlobItems[idx] isDir := false if blobInfo.Properties.ContentType != nil { isDir = (*blobInfo.Properties.ContentType == dirMimeType) @@ -776,8 +779,8 @@ func (fs *AzureBlobFs) hasContents(name string) (bool, error) { return result, err } -func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL azblob.BlockBlobURL, - httpHeaders azblob.BlobHTTPHeaders) error { +func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL *azblob.BlockBlobURL, + httpHeaders *azblob.BlobHTTPHeaders) error { partSize := fs.config.UploadPartSize guard := make(chan struct{}, fs.config.UploadConcurrency) blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute @@ -852,7 +855,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read return poolError } - _, err := blockBlobURL.CommitBlockList(ctx, blocks, httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{}, + _, err := blockBlobURL.CommitBlockList(ctx, blocks, *httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{}, azblob.AccessTierType(fs.config.AccessTier), nil, azblob.ClientProvidedKeyOptions{}) return err } diff --git a/webdavd/handler.go b/webdavd/handler.go index 1987f033..7d18a54d 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -264,12 +264,12 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f if vfs.IsLocalOrSFTPFs(c.Fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { - dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck } } else { initialSize = fileSize diff --git a/webdavd/server.go b/webdavd/server.go index c632db33..eb28ab3b 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -181,7 +181,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) - dataprovider.UpdateLastLogin(user) //nolint:errcheck + dataprovider.UpdateLastLogin(&user) //nolint:errcheck if s.checkRequestMethod(ctx, r, connection) { w.Header().Set("Content-Type", "text/xml; charset=utf-8")