micro optimizations spotted using the go-critic linter

This commit is contained in:
Nicola Murino 2021-02-16 19:11:36 +01:00
parent b1ce6eb85b
commit be9230e85b
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
29 changed files with 160 additions and 189 deletions

View file

@ -122,7 +122,7 @@ Command-line flags should be specified in the Subsystem declaration.
os.Exit(1) os.Exit(1)
} }
} }
err = sftpd.ServeSubSystemConnection(user, connectionID, os.Stdin, os.Stdout) err = sftpd.ServeSubSystemConnection(&user, connectionID, os.Stdin, os.Stdout)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logger.Warn(logSender, connectionID, "serving subsystem finished with error: %v", err) logger.Warn(logSender, connectionID, "serving subsystem finished with error: %v", err)
os.Exit(1) os.Exit(1)

View file

@ -52,7 +52,7 @@ func SSHCommandActionNotification(user *dataprovider.User, filePath, target, ssh
// ActionHandler handles a notification for a Protocol Action. // ActionHandler handles a notification for a Protocol Action.
type ActionHandler interface { type ActionHandler interface {
Handle(notification ActionNotification) error Handle(notification *ActionNotification) error
} }
// ActionNotification defines a notification for a Protocol Action. // ActionNotification defines a notification for a Protocol Action.
@ -75,7 +75,7 @@ func newActionNotification(
operation, filePath, target, sshCmd, protocol string, operation, filePath, target, sshCmd, protocol string,
fileSize int64, fileSize int64,
err error, err error,
) ActionNotification { ) *ActionNotification {
var bucket, endpoint string var bucket, endpoint string
status := 1 status := 1
@ -99,7 +99,7 @@ func newActionNotification(
status = 0 status = 0
} }
return ActionNotification{ return &ActionNotification{
Action: operation, Action: operation,
Username: user.Username, Username: user.Username,
Path: filePath, Path: filePath,
@ -116,7 +116,7 @@ func newActionNotification(
type defaultActionHandler struct{} type defaultActionHandler struct{}
func (h *defaultActionHandler) Handle(notification ActionNotification) error { func (h *defaultActionHandler) Handle(notification *ActionNotification) error {
if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) { if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) {
return errUnconfiguredAction return errUnconfiguredAction
} }
@ -134,7 +134,7 @@ func (h *defaultActionHandler) Handle(notification ActionNotification) error {
return h.handleCommand(notification) return h.handleCommand(notification)
} }
func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error { func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) error {
u, err := url.Parse(Config.Actions.Hook) u, err := url.Parse(Config.Actions.Hook)
if err != nil { if err != nil {
logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err) logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err)
@ -165,7 +165,7 @@ func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error
return err return err
} }
func (h *defaultActionHandler) handleCommand(notification ActionNotification) error { func (h *defaultActionHandler) handleCommand(notification *ActionNotification) error {
if !filepath.IsAbs(Config.Actions.Hook) { if !filepath.IsAbs(Config.Actions.Hook) {
err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook) err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook)
logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err) logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err)
@ -188,7 +188,7 @@ func (h *defaultActionHandler) handleCommand(notification ActionNotification) er
return err return err
} }
func notificationAsEnvVars(notification ActionNotification) []string { func notificationAsEnvVars(notification *ActionNotification) []string {
return []string{ return []string{
fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action), fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action),
fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username), fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username),

View file

@ -201,7 +201,7 @@ type actionHandlerStub struct {
called bool called bool
} }
func (h *actionHandlerStub) Handle(notification ActionNotification) error { func (h *actionHandlerStub) Handle(notification *ActionNotification) error {
h.called = true h.called = true
return nil return nil
@ -215,7 +215,7 @@ func TestInitializeActionHandler(t *testing.T) {
InitializeActionHandler(&defaultActionHandler{}) InitializeActionHandler(&defaultActionHandler{})
}) })
err := actionHandler.Handle(ActionNotification{}) err := actionHandler.Handle(&ActionNotification{})
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, handler.called) assert.True(t, handler.called)

View file

@ -630,13 +630,13 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
} }
// GetStats returns stats for active connections // GetStats returns stats for active connections
func (conns *ActiveConnections) GetStats() []ConnectionStatus { func (conns *ActiveConnections) GetStats() []*ConnectionStatus {
conns.RLock() conns.RLock()
defer conns.RUnlock() defer conns.RUnlock()
stats := make([]ConnectionStatus, 0, len(conns.connections)) stats := make([]*ConnectionStatus, 0, len(conns.connections))
for _, c := range conns.connections { for _, c := range conns.connections {
stat := ConnectionStatus{ stat := &ConnectionStatus{
Username: c.GetUsername(), Username: c.GetUsername(),
ConnectionID: c.GetID(), ConnectionID: c.GetID(),
ClientVersion: c.GetClientVersion(), ClientVersion: c.GetClientVersion(),
@ -675,14 +675,14 @@ type ConnectionStatus struct {
} }
// GetConnectionDuration returns the connection duration as string // GetConnectionDuration returns the connection duration as string
func (c ConnectionStatus) GetConnectionDuration() string { func (c *ConnectionStatus) GetConnectionDuration() string {
elapsed := time.Since(utils.GetTimeFromMsecSinceEpoch(c.ConnectionTime)) elapsed := time.Since(utils.GetTimeFromMsecSinceEpoch(c.ConnectionTime))
return utils.GetDurationAsString(elapsed) return utils.GetDurationAsString(elapsed)
} }
// GetConnectionInfo returns connection info. // GetConnectionInfo returns connection info.
// Protocol,Client Version and RemoteAddress are returned. // Protocol,Client Version and RemoteAddress are returned.
func (c ConnectionStatus) GetConnectionInfo() string { func (c *ConnectionStatus) GetConnectionInfo() string {
var result strings.Builder var result strings.Builder
result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress)) result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress))
@ -702,7 +702,7 @@ func (c ConnectionStatus) GetConnectionInfo() string {
} }
// GetTransfersAsString returns the active transfers as string // GetTransfersAsString returns the active transfers as string
func (c ConnectionStatus) GetTransfersAsString() string { func (c *ConnectionStatus) GetTransfersAsString() string {
result := "" result := ""
for _, t := range c.Transfers { for _, t := range c.Transfers {
if result != "" { if result != "" {

View file

@ -37,10 +37,10 @@ type BaseConnection struct {
} }
// NewBaseConnection returns a new BaseConnection // NewBaseConnection returns a new BaseConnection
func NewBaseConnection(ID, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection { func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection {
connID := ID connID := id
if utils.IsStringInSlice(protocol, supportedProtocols) { if utils.IsStringInSlice(protocol, supportedProtocols) {
connID = fmt.Sprintf("%v_%v", protocol, ID) connID = fmt.Sprintf("%v_%v", protocol, id)
} }
return &BaseConnection{ return &BaseConnection{
ID: connID, ID: connID,
@ -272,12 +272,12 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo
if info.Mode()&os.ModeSymlink == 0 { if info.Mode()&os.ModeSymlink == 0 {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
} }
} }
if actionErr != nil { if actionErr != nil {
@ -577,12 +577,12 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er
sizeDiff := initialSize - size sizeDiff := initialSize - size
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
} }
} }
return err return err
@ -835,64 +835,64 @@ func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetP
return true return true
} }
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder vfs.VirtualFolder, initialSize, func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize,
filesSize int64, numFiles int) { filesSize int64, numFiles int) {
if sourceFolder.MappedPath == dstFolder.MappedPath { if sourceFolder.MappedPath == dstFolder.MappedPath {
// both files are inside the same virtual folder // both files are inside the same virtual folder
if initialSize != -1 { if initialSize != -1 {
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() { if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, -numFiles, -initialSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck
} }
} }
return return
} }
// files are inside different virtual folders // files are inside different virtual folders
dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
if sourceFolder.IsIncludedInUserQuota() { if sourceFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
} }
if initialSize == -1 { if initialSize == -1 {
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() { if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
} }
} else { } else {
// we cannot have a directory here, initialSize != -1 only for files // we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() { if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
} }
} }
} }
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
// move between a virtual folder and the user home dir // move between a virtual folder and the user home dir
dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
if sourceFolder.IsIncludedInUserQuota() { if sourceFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
} }
if initialSize == -1 { if initialSize == -1 {
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
} else { } else {
// we cannot have a directory here, initialSize != -1 only for files // we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
} }
} }
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
// move between the user home dir and a virtual folder // move between the user home dir and a virtual folder
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
if initialSize == -1 { if initialSize == -1 {
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() { if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
} }
} else { } else {
// we cannot have a directory here, initialSize != -1 only for files // we cannot have a directory here, initialSize != -1 only for files
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
if dstFolder.IsIncludedInUserQuota() { if dstFolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
} }
} }
} }
@ -909,7 +909,7 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
// both files are contained inside the user home dir // both files are contained inside the user home dir
if initialSize != -1 { if initialSize != -1 {
// we cannot have a directory here // we cannot have a directory here
dataprovider.UpdateUserQuota(c.User, -1, -initialSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck
} }
return nil return nil
} }
@ -932,13 +932,13 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
return err return err
} }
if errSrc == nil && errDst == nil { if errSrc == nil && errDst == nil {
c.updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder, initialSize, filesSize, numFiles) c.updateQuotaMoveBetweenVFolders(&sourceFolder, &dstFolder, initialSize, filesSize, numFiles)
} }
if errSrc == nil && errDst != nil { if errSrc == nil && errDst != nil {
c.updateQuotaMoveFromVFolder(sourceFolder, initialSize, filesSize, numFiles) c.updateQuotaMoveFromVFolder(&sourceFolder, initialSize, filesSize, numFiles)
} }
if errSrc != nil && errDst == nil { if errSrc != nil && errDst == nil {
c.updateQuotaMoveToVFolder(dstFolder, initialSize, filesSize, numFiles) c.updateQuotaMoveToVFolder(&dstFolder, initialSize, filesSize, numFiles)
} }
return nil return nil
} }

View file

@ -1054,7 +1054,7 @@ func TestHasSpace(t *testing.T) {
folder, err := dataprovider.GetFolderByName(folderName) folder, err := dataprovider.GetFolderByName(folderName)
assert.NoError(t, err) assert.NoError(t, err)
err = dataprovider.UpdateVirtualFolderQuota(folder, 10, 1048576, true) err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 1048576, true)
assert.NoError(t, err) assert.NoError(t, err)
quotaResult = c.HasSpace(true, false, "/vdir/file1") quotaResult = c.HasSpace(true, false, "/vdir/file1")
assert.False(t, quotaResult.HasSpace) assert.False(t, quotaResult.HasSpace)
@ -1105,14 +1105,14 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
folder2, err := dataprovider.GetFolderByName(folderName2) folder2, err := dataprovider.GetFolderByName(folderName2)
assert.NoError(t, err) assert.NoError(t, err)
err = dataprovider.UpdateVirtualFolderQuota(folder1, 1, 100, true) err = dataprovider.UpdateVirtualFolderQuota(&folder1, 1, 100, true)
assert.NoError(t, err) assert.NoError(t, err)
err = dataprovider.UpdateVirtualFolderQuota(folder2, 2, 150, true) err = dataprovider.UpdateVirtualFolderQuota(&folder2, 2, 150, true)
assert.NoError(t, err) assert.NoError(t, err)
fs, err := user.GetFilesystem("id") fs, err := user.GetFilesystem("id")
assert.NoError(t, err) assert.NoError(t, err)
c := NewBaseConnection("", ProtocolSFTP, user, fs) c := NewBaseConnection("", ProtocolSFTP, user, fs)
c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[0], user.VirtualFolders[1], -1, 100, 1) c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[0], &user.VirtualFolders[1], -1, 100, 1)
folder1, err = dataprovider.GetFolderByName(folderName1) folder1, err = dataprovider.GetFolderByName(folderName1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, folder1.UsedQuotaFiles) assert.Equal(t, 0, folder1.UsedQuotaFiles)
@ -1122,7 +1122,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
assert.Equal(t, 3, folder2.UsedQuotaFiles) assert.Equal(t, 3, folder2.UsedQuotaFiles)
assert.Equal(t, int64(250), folder2.UsedQuotaSize) assert.Equal(t, int64(250), folder2.UsedQuotaSize)
c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[1], user.VirtualFolders[0], 10, 100, 1) c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[1], &user.VirtualFolders[0], 10, 100, 1)
folder1, err = dataprovider.GetFolderByName(folderName1) folder1, err = dataprovider.GetFolderByName(folderName1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, folder1.UsedQuotaFiles) assert.Equal(t, 0, folder1.UsedQuotaFiles)
@ -1132,9 +1132,9 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
assert.Equal(t, 2, folder2.UsedQuotaFiles) assert.Equal(t, 2, folder2.UsedQuotaFiles)
assert.Equal(t, int64(150), folder2.UsedQuotaSize) assert.Equal(t, int64(150), folder2.UsedQuotaSize)
err = dataprovider.UpdateUserQuota(user, 1, 100, true) err = dataprovider.UpdateUserQuota(&user, 1, 100, true)
assert.NoError(t, err) assert.NoError(t, err)
c.updateQuotaMoveFromVFolder(user.VirtualFolders[1], -1, 50, 1) c.updateQuotaMoveFromVFolder(&user.VirtualFolders[1], -1, 50, 1)
folder2, err = dataprovider.GetFolderByName(folderName2) folder2, err = dataprovider.GetFolderByName(folderName2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, folder2.UsedQuotaFiles) assert.Equal(t, 1, folder2.UsedQuotaFiles)
@ -1144,7 +1144,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, 1, user.UsedQuotaFiles)
assert.Equal(t, int64(100), user.UsedQuotaSize) assert.Equal(t, int64(100), user.UsedQuotaSize)
c.updateQuotaMoveToVFolder(user.VirtualFolders[1], -1, 100, 1) c.updateQuotaMoveToVFolder(&user.VirtualFolders[1], -1, 100, 1)
folder2, err = dataprovider.GetFolderByName(folderName2) folder2, err = dataprovider.GetFolderByName(folderName2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, folder2.UsedQuotaFiles) assert.Equal(t, 2, folder2.UsedQuotaFiles)

View file

@ -266,13 +266,13 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) { if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) {
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath)) vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
sizeDiff, false) sizeDiff, false)
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
} }
return true return true
} }

View file

@ -22,8 +22,7 @@ const (
) )
var ( var (
usersBucket = []byte("users") usersBucket = []byte("users")
//usersIDIdxBucket = []byte("users_id_idx")
foldersBucket = []byte("folders") foldersBucket = []byte("folders")
adminsBucket = []byte("admins") adminsBucket = []byte("admins")
dbVersionBucket = []byte("db_version") dbVersionBucket = []byte("db_version")
@ -113,7 +112,7 @@ func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol stri
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, err return user, err
} }
return checkUserAndPass(user, password, ip, protocol) return checkUserAndPass(&user, password, ip, protocol)
} }
func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
@ -136,7 +135,7 @@ func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (Us
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, "", err return user, "", err
} }
return checkUserAndPubKey(user, pubKey) return checkUserAndPubKey(&user, pubKey)
} }
func (p *BoltProvider) updateLastLogin(username string) error { func (p *BoltProvider) updateLastLogin(username string) error {

View file

@ -625,14 +625,14 @@ func CheckUserAndPass(username, password, ip, protocol string) (User, error) {
if err != nil { if err != nil {
return user, err return user, err
} }
return checkUserAndPass(user, password, ip, protocol) return checkUserAndPass(&user, password, ip, protocol)
} }
if config.PreLoginHook != "" { if config.PreLoginHook != "" {
user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol) user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol)
if err != nil { if err != nil {
return user, err return user, err
} }
return checkUserAndPass(user, password, ip, protocol) return checkUserAndPass(&user, password, ip, protocol)
} }
return provider.validateUserAndPass(username, password, ip, protocol) return provider.validateUserAndPass(username, password, ip, protocol)
} }
@ -644,14 +644,14 @@ func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string) (Us
if err != nil { if err != nil {
return user, "", err return user, "", err
} }
return checkUserAndPubKey(user, pubKey) return checkUserAndPubKey(&user, pubKey)
} }
if config.PreLoginHook != "" { if config.PreLoginHook != "" {
user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol) user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol)
if err != nil { if err != nil {
return user, "", err return user, "", err
} }
return checkUserAndPubKey(user, pubKey) return checkUserAndPubKey(&user, pubKey)
} }
return provider.validateUserAndPubKey(username, pubKey) return provider.validateUserAndPubKey(username, pubKey)
} }
@ -671,11 +671,11 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard
if err != nil { if err != nil {
return user, err return user, err
} }
return doKeyboardInteractiveAuth(user, authHook, client, ip, protocol) return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol)
} }
// UpdateLastLogin updates the last login fields for the given SFTP user // UpdateLastLogin updates the last login fields for the given SFTP user
func UpdateLastLogin(user User) error { func UpdateLastLogin(user *User) error {
lastLogin := utils.GetTimeFromMsecSinceEpoch(user.LastLogin) lastLogin := utils.GetTimeFromMsecSinceEpoch(user.LastLogin)
diff := -time.Until(lastLogin) diff := -time.Until(lastLogin)
if diff < 0 || diff > lastLoginMinDelay { if diff < 0 || diff > lastLoginMinDelay {
@ -690,7 +690,7 @@ func UpdateLastLogin(user User) error {
// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd. // UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd.
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error { func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error {
if config.TrackQuota == 0 { if config.TrackQuota == 0 {
return &MethodDisabledError{err: trackQuotaDisabledError} return &MethodDisabledError{err: trackQuotaDisabledError}
} else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() { } else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() {
@ -704,7 +704,7 @@ func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error {
// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd. // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. // If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
func UpdateVirtualFolderQuota(vfolder vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error { func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
if config.TrackQuota == 0 { if config.TrackQuota == 0 {
return &MethodDisabledError{err: trackQuotaDisabledError} return &MethodDisabledError{err: trackQuotaDisabledError}
} }
@ -1482,53 +1482,53 @@ func isPasswordOK(user *User, password string) (bool, error) {
return match, err return match, err
} }
func checkUserAndPass(user User, password, ip, protocol string) (User, error) { func checkUserAndPass(user *User, password, ip, protocol string) (User, error) {
err := checkLoginConditions(&user) err := checkLoginConditions(user)
if err != nil { if err != nil {
return user, err return *user, err
} }
if user.Password == "" { if user.Password == "" {
return user, errors.New("Credentials cannot be null or empty") return *user, errors.New("Credentials cannot be null or empty")
} }
hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol) hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol)
if err != nil { if err != nil {
providerLog(logger.LevelDebug, "error executing check password hook: %v", err) providerLog(logger.LevelDebug, "error executing check password hook: %v", err)
return user, errors.New("Unable to check credentials") return *user, errors.New("Unable to check credentials")
} }
switch hookResponse.Status { switch hookResponse.Status {
case -1: case -1:
// no hook configured // no hook configured
case 1: case 1:
providerLog(logger.LevelDebug, "password accepted by check password hook") providerLog(logger.LevelDebug, "password accepted by check password hook")
return user, nil return *user, nil
case 2: case 2:
providerLog(logger.LevelDebug, "partial success from check password hook") providerLog(logger.LevelDebug, "partial success from check password hook")
password = hookResponse.ToVerify password = hookResponse.ToVerify
default: default:
providerLog(logger.LevelDebug, "password rejected by check password hook, status: %v", hookResponse.Status) providerLog(logger.LevelDebug, "password rejected by check password hook, status: %v", hookResponse.Status)
return user, ErrInvalidCredentials return *user, ErrInvalidCredentials
} }
match, err := isPasswordOK(&user, password) match, err := isPasswordOK(user, password)
if !match { if !match {
err = ErrInvalidCredentials err = ErrInvalidCredentials
} }
return user, err return *user, err
} }
func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) { func checkUserAndPubKey(user *User, pubKey []byte) (User, string, error) {
err := checkLoginConditions(&user) err := checkLoginConditions(user)
if err != nil { if err != nil {
return user, "", err return *user, "", err
} }
if len(user.PublicKeys) == 0 { if len(user.PublicKeys) == 0 {
return user, "", ErrInvalidCredentials return *user, "", ErrInvalidCredentials
} }
for i, k := range user.PublicKeys { for i, k := range user.PublicKeys {
storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
if err != nil { if err != nil {
providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err) providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
return user, "", err return *user, "", err
} }
if bytes.Equal(storedPubKey.Marshal(), pubKey) { if bytes.Equal(storedPubKey.Marshal(), pubKey) {
certInfo := "" certInfo := ""
@ -1537,10 +1537,10 @@ func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) {
certInfo = fmt.Sprintf(" %v ID: %v Serial: %v CA: %v", cert.Type(), cert.KeyId, cert.Serial, certInfo = fmt.Sprintf(" %v ID: %v Serial: %v CA: %v", cert.Type(), cert.KeyId, cert.Serial,
ssh.FingerprintSHA256(cert.SignatureKey)) ssh.FingerprintSHA256(cert.SignatureKey))
} }
return user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil return *user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil
} }
} }
return user, "", ErrInvalidCredentials return *user, "", ErrInvalidCredentials
} }
func compareUnixPasswordAndHash(user *User, password string) (bool, error) { func compareUnixPasswordAndHash(user *User, password string) (bool, error) {
@ -1712,7 +1712,7 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key
return response, err return response, err
} }
func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
authResult := 0 authResult := 0
var url *url.URL var url *url.URL
url, err := url.Parse(authHook) url, err := url.Parse(authHook)
@ -1754,7 +1754,7 @@ func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.K
} }
func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse, func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
user User, ip, protocol string) ([]string, error) { user *User, ip, protocol string) ([]string, error) {
questions := response.Questions questions := response.Questions
answers, err := client(user.Username, response.Instruction, questions, response.Echos) answers, err := client(user.Username, response.Instruction, questions, response.Echos)
if err != nil { if err != nil {
@ -1779,7 +1779,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp
} }
func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse, func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
user User, stdin io.WriteCloser, ip, protocol string) error { user *User, stdin io.WriteCloser, ip, protocol string) error {
answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol) answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
if err != nil { if err != nil {
return err return err
@ -1798,7 +1798,7 @@ func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge,
return nil return nil
} }
func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
authResult := 0 authResult := 0
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel() defer cancel()
@ -1856,7 +1856,7 @@ func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.Ke
return authResult, err return authResult, err
} }
func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) { func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) {
var authResult int var authResult int
var err error var err error
if strings.HasPrefix(authHook, "http") { if strings.HasPrefix(authHook, "http") {
@ -1865,16 +1865,16 @@ func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardIn
authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol) authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol)
} }
if err != nil { if err != nil {
return user, err return *user, err
} }
if authResult != 1 { if authResult != 1 {
return user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult) return *user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult)
} }
err = checkLoginConditions(&user) err = checkLoginConditions(user)
if err != nil { if err != nil {
return user, err return *user, err
} }
return user, nil return *user, nil
} }
func isCheckPasswordHookDefined(protocol string) bool { func isCheckPasswordHookDefined(protocol string) bool {

View file

@ -99,7 +99,7 @@ func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol st
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, err return user, err
} }
return checkUserAndPass(user, password, ip, protocol) return checkUserAndPass(&user, password, ip, protocol)
} }
func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) { func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
@ -112,7 +112,7 @@ func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, "", err return user, "", err
} }
return checkUserAndPubKey(user, pubKey) return checkUserAndPubKey(&user, pubKey)
} }
func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {

View file

@ -225,23 +225,23 @@ func (p *MySQLProvider) initializeDatabase() error {
return ErrNoInitRequired return ErrNoInitRequired
} }
sqlUsers := strings.Replace(mysqlUsersTableSQL, "{{users}}", sqlTableUsers, 1) sqlUsers := strings.Replace(mysqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
tx, err := p.dbHandle.Begin() ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
tx, err := p.dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(sqlUsers) _, err = tx.Exec(sqlUsers)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(mysqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(mysqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()

View file

@ -229,23 +229,23 @@ func (p *PGSQLProvider) initializeDatabase() error {
return ErrNoInitRequired return ErrNoInitRequired
} }
sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1) sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
tx, err := p.dbHandle.Begin() ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
tx, err := p.dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(sqlUsers) _, err = tx.Exec(sqlUsers)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()

View file

@ -224,7 +224,7 @@ func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHan
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, err return user, err
} }
return checkUserAndPass(user, password, ip, protocol) return checkUserAndPass(&user, password, ip, protocol)
} }
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) { func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
@ -237,7 +237,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sq
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err) providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
return user, "", err return user, "", err
} }
return checkUserAndPubKey(user, pubKey) return checkUserAndPubKey(&user, pubKey)
} }
func sqlCommonCheckAvailability(dbHandle *sql.DB) error { func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
@ -313,6 +313,7 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel() defer cancel()
tx, err := dbHandle.BeginTx(ctx, nil) tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
@ -321,40 +322,33 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
stmt, err := tx.PrepareContext(ctx, q) stmt, err := tx.PrepareContext(ctx, q)
if err != nil { if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
sqlCommonRollbackTransaction(tx)
return err return err
} }
defer stmt.Close() defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON() permissions, err := user.GetPermissionsAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
publicKeys, err := user.GetPublicKeysAsJSON() publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
filters, err := user.GetFiltersAsJSON() filters, err := user.GetFiltersAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
fsConfig, err := user.GetFsConfigAsJSON() fsConfig, err := user.GetFsConfigAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
string(fsConfig), user.AdditionalInfo) string(fsConfig), user.AdditionalInfo)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
err = generateVirtualFoldersMapping(ctx, user, tx) err = generateVirtualFoldersMapping(ctx, user, tx)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()
@ -367,6 +361,7 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel() defer cancel()
tx, err := dbHandle.BeginTx(ctx, nil) tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
@ -375,40 +370,33 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
stmt, err := tx.PrepareContext(ctx, q) stmt, err := tx.PrepareContext(ctx, q)
if err != nil { if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
sqlCommonRollbackTransaction(tx)
return err return err
} }
defer stmt.Close() defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON() permissions, err := user.GetPermissionsAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
publicKeys, err := user.GetPublicKeysAsJSON() publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
filters, err := user.GetFiltersAsJSON() filters, err := user.GetFiltersAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
fsConfig, err := user.GetFsConfigAsJSON() fsConfig, err := user.GetFsConfigAsJSON()
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
string(filters), string(fsConfig), user.AdditionalInfo, user.ID) string(filters), string(fsConfig), user.AdditionalInfo, user.ID)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
err = generateVirtualFoldersMapping(ctx, user, tx) err = generateVirtualFoldersMapping(ctx, user, tx)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()
@ -979,13 +967,6 @@ func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int6
return usedFiles, usedSize, err return usedFiles, usedSize, err
} }
func sqlCommonRollbackTransaction(tx *sql.Tx) {
err := tx.Rollback()
if err != nil {
providerLog(logger.LevelWarn, "error rolling back transaction: %v", err)
}
}
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) { func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
var result schemaVersion var result schemaVersion
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
@ -1030,13 +1011,11 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersi
} }
_, err = tx.ExecContext(ctx, q) _, err = tx.ExecContext(ctx, q)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
} }
err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()
@ -1130,6 +1109,7 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping) sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel() defer cancel()
tx, err := dbHandle.BeginTx(ctx, nil) tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
@ -1140,25 +1120,14 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
} }
_, err = tx.ExecContext(ctx, q) _, err = tx.ExecContext(ctx, q)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
} }
/*_, err = sqlCommonRestoreCompatVirtualFolders(ctx, users, tx)
if err != nil {
sqlCommonRollbackTransaction(tx)
return err
}*/
err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4) err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()
/*if err == nil {
go updateVFoldersQuotaAfterRestore(foldersToScan)
}
return err*/
} }
//nolint:dupl //nolint:dupl

View file

@ -262,23 +262,23 @@ func (p *SQLiteProvider) initializeDatabase() error {
return ErrNoInitRequired return ErrNoInitRequired
} }
sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1) sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1)
tx, err := p.dbHandle.Begin() ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
tx, err := p.dbHandle.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(sqlUsers) _, err = tx.Exec(sqlUsers)
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(sqliteSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(sqliteSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1)) _, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil { if err != nil {
sqlCommonRollbackTransaction(tx)
return err return err
} }
return tx.Commit() return tx.Commit()

View file

@ -470,12 +470,12 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file
if vfs.IsLocalOrSFTPFs(c.Fs) { if vfs.IsLocalOrSFTPFs(c.Fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
initialSize = fileSize initialSize = fileSize

View file

@ -152,7 +152,7 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID()) connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID())
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v", connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, user.Username, user.HomeDir, ipAddr) user.ID, user.Username, user.HomeDir, ipAddr)
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(&user) //nolint:errcheck
return connection, nil return connection, nil
} }

View file

@ -58,7 +58,7 @@ func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) {
return return
} }
defer common.QuotaScans.RemoveUserQuotaScan(user.Username) defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
err = dataprovider.UpdateUserQuota(user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset) err = dataprovider.UpdateUserQuota(&user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err)) sendAPIResponse(w, r, err, "", getRespStatus(err))
} else { } else {
@ -94,7 +94,7 @@ func updateVFolderQuotaUsage(w http.ResponseWriter, r *http.Request) {
return return
} }
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
err = dataprovider.UpdateVirtualFolderQuota(folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset) err = dataprovider.UpdateVirtualFolderQuota(&folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset)
if err != nil { if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err)) sendAPIResponse(w, r, err, "", getRespStatus(err))
} else { } else {
@ -165,7 +165,7 @@ func doQuotaScan(user dataprovider.User) error {
logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err) logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err)
return err return err
} }
err = dataprovider.UpdateUserQuota(user, numFiles, size, true) err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err) logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err)
return err return err
} }
@ -178,7 +178,7 @@ func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error {
logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err) logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err)
return err return err
} }
err = dataprovider.UpdateVirtualFolderQuota(folder, numFiles, size, true) err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
logger.Debug(logSender, "", "virtual folder %#v scanned, error: %v", folder.Name, err) logger.Debug(logSender, "", "virtual folder %#v scanned, error: %v", folder.Name, err)
return err return err
} }

View file

@ -123,7 +123,7 @@ type foldersPage struct {
type connectionsPage struct { type connectionsPage struct {
basePage basePage
Connections []common.ConnectionStatus Connections []*common.ConnectionStatus
} }
type statusPage struct { type statusPage struct {

View file

@ -69,7 +69,7 @@ func (l *LeveledLogger) addKeysAndValues(ev *zerolog.Event, keysAndValues ...int
extra := keysAndValues[kvLen-1] extra := keysAndValues[kvLen-1]
keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra) keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra)
} }
for i := 0; i < len(keysAndValues); i = i + 2 { for i := 0; i < len(keysAndValues); i += 2 {
key, val := keysAndValues[i], keysAndValues[i+1] key, val := keysAndValues[i], keysAndValues[i+1]
if keyStr, ok := key.(string); ok { if keyStr, ok := key.(string); ok {
ev.Str(keyStr, fmt.Sprintf("%v", val)) ev.Str(keyStr, fmt.Sprintf("%v", val))

View file

@ -412,12 +412,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate { if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
initialSize = fileSize initialSize = fileSize
@ -460,7 +460,7 @@ func (c *Connection) getStatVFSFromQuotaResult(name string, quotaResult vfs.Quot
bsize := uint64(4096) bsize := uint64(4096)
for bsize > uint64(quotaResult.QuotaSize) { for bsize > uint64(quotaResult.QuotaSize) {
bsize = bsize / 4 bsize /= 4
} }
blocks := uint64(quotaResult.QuotaSize) / bsize blocks := uint64(quotaResult.QuotaSize) / bsize
bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize

View file

@ -365,7 +365,7 @@ func TestUploadFiles(t *testing.T) {
func TestWithInvalidHome(t *testing.T) { func TestWithInvalidHome(t *testing.T) {
u := dataprovider.User{} u := dataprovider.User{}
u.HomeDir = "home_rel_path" //nolint:goconst u.HomeDir = "home_rel_path" //nolint:goconst
_, err := loginUser(u, dataprovider.LoginMethodPassword, "", nil) _, err := loginUser(&u, dataprovider.LoginMethodPassword, "", nil)
assert.Error(t, err, "login a user with an invalid home_dir must fail") assert.Error(t, err, "login a user with an invalid home_dir must fail")
u.HomeDir = os.TempDir() u.HomeDir = os.TempDir()
@ -1890,7 +1890,7 @@ func TestRecursiveCopyErrors(t *testing.T) {
func TestSFTPSubSystem(t *testing.T) { func TestSFTPSubSystem(t *testing.T) {
permissions := make(map[string][]string) permissions := make(map[string][]string)
permissions["/"] = []string{dataprovider.PermAny} permissions["/"] = []string{dataprovider.PermAny}
user := dataprovider.User{ user := &dataprovider.User{
Permissions: permissions, Permissions: permissions,
HomeDir: os.TempDir(), HomeDir: os.TempDir(),
} }

View file

@ -213,12 +213,12 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead
if vfs.IsLocalOrSFTPFs(c.connection.Fs) { if vfs.IsLocalOrSFTPFs(c.connection.Fs) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
initialSize = fileSize initialSize = fileSize

View file

@ -408,7 +408,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, ipAddr) user.ID, loginType, user.Username, user.HomeDir, ipAddr)
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(&user) //nolint:errcheck
sshConnection := common.NewSSHConnection(connectionID, conn) sshConnection := common.NewSSHConnection(connectionID, conn)
common.Connections.AddSSHConnection(sshConnection) common.Connections.AddSSHConnection(sshConnection)
@ -557,7 +557,7 @@ func checkRootPath(user *dataprovider.User, connectionID string) error {
return nil return nil
} }
func loginUser(user dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) { func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
connectionID := "" connectionID := ""
if conn != nil { if conn != nil {
connectionID = hex.EncodeToString(conn.SessionID()) connectionID = hex.EncodeToString(conn.SessionID())
@ -817,7 +817,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User()) logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
return certPerm, ssh.ErrPartialSuccess return certPerm, ssh.ErrPartialSuccess
} }
sshPerm, err = loginUser(user, method, keyID, conn) sshPerm, err = loginUser(&user, method, keyID, conn)
if err == nil && certPerm != nil { if err == nil && certPerm != nil {
// if we have a SSH user cert we need to merge certificate permissions with our ones // if we have a SSH user cert we need to merge certificate permissions with our ones
// we only set Extensions, so CriticalOptions are always the ones from the certificate // we only set Extensions, so CriticalOptions are always the ones from the certificate
@ -845,7 +845,7 @@ func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass
} }
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil { if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil {
sshPerm, err = loginUser(user, method, "", conn) sshPerm, err = loginUser(&user, method, "", conn)
} }
user.Username = conn.User() user.Username = conn.User()
updateLoginMetrics(&user, ipAddr, method, err) updateLoginMetrics(&user, ipAddr, method, err)
@ -864,7 +864,7 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client, if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client,
ipAddr, common.ProtocolSSH); err == nil { ipAddr, common.ProtocolSSH); err == nil {
sshPerm, err = loginUser(user, method, "", conn) sshPerm, err = loginUser(&user, method, "", conn)
} }
user.Username = conn.User() user.Username = conn.User()
updateLoginMetrics(&user, ipAddr, method, err) updateLoginMetrics(&user, ipAddr, method, err)

View file

@ -6490,7 +6490,7 @@ func TestStatVFSCloudBackend(t *testing.T) {
if assert.NoError(t, err) { if assert.NoError(t, err) {
defer client.Close() defer client.Close()
err = dataprovider.UpdateUserQuota(user, 100, 8192, true) err = dataprovider.UpdateUserQuota(&user, 100, 8192, true)
assert.NoError(t, err) assert.NoError(t, err)
stat, err := client.StatVFS("/") stat, err := client.StatVFS("/")
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -248,12 +248,12 @@ func (c *sshCommand) handeSFTPGoRemove() error {
func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) { func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath) vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
} }
} }

View file

@ -35,7 +35,7 @@ func newSubsystemChannel(reader io.Reader, writer io.Writer) *subsystemChannel {
} }
// ServeSubSystemConnection handles a connection as SSH subsystem // ServeSubSystemConnection handles a connection as SSH subsystem
func ServeSubSystemConnection(user dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error { func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error {
fs, err := user.GetFilesystem(connectionID) fs, err := user.GetFilesystem(connectionID)
if err != nil { if err != nil {
return err return err
@ -44,7 +44,7 @@ func ServeSubSystemConnection(user dataprovider.User, connectionID string, reade
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(user) //nolint:errcheck
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, *user, fs),
ClientVersion: "", ClientVersion: "",
RemoteAddr: &net.IPAddr{}, RemoteAddr: &net.IPAddr{},
channel: newSubsystemChannel(reader, writer), channel: newSubsystemChannel(reader, writer),

View file

@ -253,7 +253,7 @@ func (fs *AzureBlobFs) Create(name string, flag int) (File, *PipeWriter, func(),
// if we shutdown Azurite while uploading it hangs, so we use our own wrapper for // if we shutdown Azurite while uploading it hangs, so we use our own wrapper for
// the low level functions // the low level functions
_, err := azblob.UploadStreamToBlockBlob(ctx, r, blobBlockURL, uploadOptions)*/ _, err := azblob.UploadStreamToBlockBlob(ctx, r, blobBlockURL, uploadOptions)*/
err := fs.handleMultipartUpload(ctx, r, blobBlockURL, headers) err := fs.handleMultipartUpload(ctx, r, &blobBlockURL, &headers)
r.CloseWithError(err) //nolint:errcheck r.CloseWithError(err) //nolint:errcheck
p.Done(err) p.Done(err)
fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, r.GetReadedBytes(), err) fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, r.GetReadedBytes(), err)
@ -438,7 +438,8 @@ func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) {
result = append(result, NewFileInfo(name, true, 0, time.Now(), false)) result = append(result, NewFileInfo(name, true, 0, time.Now(), false))
prefixes[strings.TrimSuffix(name, "/")] = true prefixes[strings.TrimSuffix(name, "/")] = true
} }
for _, blobInfo := range listBlob.Segment.BlobItems { for idx := range listBlob.Segment.BlobItems {
blobInfo := &listBlob.Segment.BlobItems[idx]
name := strings.TrimPrefix(blobInfo.Name, prefix) name := strings.TrimPrefix(blobInfo.Name, prefix)
size := int64(0) size := int64(0)
if blobInfo.Properties.ContentLength != nil { if blobInfo.Properties.ContentLength != nil {
@ -556,7 +557,8 @@ func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) {
return numFiles, size, err return numFiles, size, err
} }
marker = listBlob.NextMarker marker = listBlob.NextMarker
for _, blobInfo := range listBlob.Segment.BlobItems { for idx := range listBlob.Segment.BlobItems {
blobInfo := &listBlob.Segment.BlobItems[idx]
isDir := false isDir := false
if blobInfo.Properties.ContentType != nil { if blobInfo.Properties.ContentType != nil {
isDir = (*blobInfo.Properties.ContentType == dirMimeType) isDir = (*blobInfo.Properties.ContentType == dirMimeType)
@ -637,7 +639,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error {
return err return err
} }
marker = listBlob.NextMarker marker = listBlob.NextMarker
for _, blobInfo := range listBlob.Segment.BlobItems { for idx := range listBlob.Segment.BlobItems {
blobInfo := &listBlob.Segment.BlobItems[idx]
isDir := false isDir := false
if blobInfo.Properties.ContentType != nil { if blobInfo.Properties.ContentType != nil {
isDir = (*blobInfo.Properties.ContentType == dirMimeType) isDir = (*blobInfo.Properties.ContentType == dirMimeType)
@ -776,8 +779,8 @@ func (fs *AzureBlobFs) hasContents(name string) (bool, error) {
return result, err return result, err
} }
func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL azblob.BlockBlobURL, func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL *azblob.BlockBlobURL,
httpHeaders azblob.BlobHTTPHeaders) error { httpHeaders *azblob.BlobHTTPHeaders) error {
partSize := fs.config.UploadPartSize partSize := fs.config.UploadPartSize
guard := make(chan struct{}, fs.config.UploadConcurrency) guard := make(chan struct{}, fs.config.UploadConcurrency)
blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
@ -852,7 +855,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
return poolError return poolError
} }
_, err := blockBlobURL.CommitBlockList(ctx, blocks, httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{}, _, err := blockBlobURL.CommitBlockList(ctx, blocks, *httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{},
azblob.AccessTierType(fs.config.AccessTier), nil, azblob.ClientProvidedKeyOptions{}) azblob.AccessTierType(fs.config.AccessTier), nil, azblob.ClientProvidedKeyOptions{})
return err return err
} }

View file

@ -264,12 +264,12 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f
if vfs.IsLocalOrSFTPFs(c.Fs) { if vfs.IsLocalOrSFTPFs(c.Fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
if vfolder.IsIncludedInUserQuota() { if vfolder.IsIncludedInUserQuota() {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
} }
} else { } else {
initialSize = fileSize initialSize = fileSize

View file

@ -181,7 +181,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
common.Connections.Add(connection) common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(&user) //nolint:errcheck
if s.checkRequestMethod(ctx, r, connection) { if s.checkRequestMethod(ctx, r, connection) {
w.Header().Set("Content-Type", "text/xml; charset=utf-8") w.Header().Set("Content-Type", "text/xml; charset=utf-8")