mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
micro optimizations spotted using the go-critic linter
This commit is contained in:
parent
b1ce6eb85b
commit
be9230e85b
29 changed files with 160 additions and 189 deletions
|
@ -122,7 +122,7 @@ Command-line flags should be specified in the Subsystem declaration.
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sftpd.ServeSubSystemConnection(user, connectionID, os.Stdin, os.Stdout)
|
err = sftpd.ServeSubSystemConnection(&user, connectionID, os.Stdin, os.Stdout)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Warn(logSender, connectionID, "serving subsystem finished with error: %v", err)
|
logger.Warn(logSender, connectionID, "serving subsystem finished with error: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
|
@ -52,7 +52,7 @@ func SSHCommandActionNotification(user *dataprovider.User, filePath, target, ssh
|
||||||
|
|
||||||
// ActionHandler handles a notification for a Protocol Action.
|
// ActionHandler handles a notification for a Protocol Action.
|
||||||
type ActionHandler interface {
|
type ActionHandler interface {
|
||||||
Handle(notification ActionNotification) error
|
Handle(notification *ActionNotification) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ActionNotification defines a notification for a Protocol Action.
|
// ActionNotification defines a notification for a Protocol Action.
|
||||||
|
@ -75,7 +75,7 @@ func newActionNotification(
|
||||||
operation, filePath, target, sshCmd, protocol string,
|
operation, filePath, target, sshCmd, protocol string,
|
||||||
fileSize int64,
|
fileSize int64,
|
||||||
err error,
|
err error,
|
||||||
) ActionNotification {
|
) *ActionNotification {
|
||||||
var bucket, endpoint string
|
var bucket, endpoint string
|
||||||
status := 1
|
status := 1
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ func newActionNotification(
|
||||||
status = 0
|
status = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return ActionNotification{
|
return &ActionNotification{
|
||||||
Action: operation,
|
Action: operation,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Path: filePath,
|
Path: filePath,
|
||||||
|
@ -116,7 +116,7 @@ func newActionNotification(
|
||||||
|
|
||||||
type defaultActionHandler struct{}
|
type defaultActionHandler struct{}
|
||||||
|
|
||||||
func (h *defaultActionHandler) Handle(notification ActionNotification) error {
|
func (h *defaultActionHandler) Handle(notification *ActionNotification) error {
|
||||||
if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) {
|
if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) {
|
||||||
return errUnconfiguredAction
|
return errUnconfiguredAction
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ func (h *defaultActionHandler) Handle(notification ActionNotification) error {
|
||||||
return h.handleCommand(notification)
|
return h.handleCommand(notification)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error {
|
func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) error {
|
||||||
u, err := url.Parse(Config.Actions.Hook)
|
u, err := url.Parse(Config.Actions.Hook)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err)
|
logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err)
|
||||||
|
@ -165,7 +165,7 @@ func (h *defaultActionHandler) handleHTTP(notification ActionNotification) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *defaultActionHandler) handleCommand(notification ActionNotification) error {
|
func (h *defaultActionHandler) handleCommand(notification *ActionNotification) error {
|
||||||
if !filepath.IsAbs(Config.Actions.Hook) {
|
if !filepath.IsAbs(Config.Actions.Hook) {
|
||||||
err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook)
|
err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook)
|
||||||
logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err)
|
logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err)
|
||||||
|
@ -188,7 +188,7 @@ func (h *defaultActionHandler) handleCommand(notification ActionNotification) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func notificationAsEnvVars(notification ActionNotification) []string {
|
func notificationAsEnvVars(notification *ActionNotification) []string {
|
||||||
return []string{
|
return []string{
|
||||||
fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action),
|
fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action),
|
||||||
fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username),
|
fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username),
|
||||||
|
|
|
@ -201,7 +201,7 @@ type actionHandlerStub struct {
|
||||||
called bool
|
called bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *actionHandlerStub) Handle(notification ActionNotification) error {
|
func (h *actionHandlerStub) Handle(notification *ActionNotification) error {
|
||||||
h.called = true
|
h.called = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -215,7 +215,7 @@ func TestInitializeActionHandler(t *testing.T) {
|
||||||
InitializeActionHandler(&defaultActionHandler{})
|
InitializeActionHandler(&defaultActionHandler{})
|
||||||
})
|
})
|
||||||
|
|
||||||
err := actionHandler.Handle(ActionNotification{})
|
err := actionHandler.Handle(&ActionNotification{})
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, handler.called)
|
assert.True(t, handler.called)
|
||||||
|
|
|
@ -630,13 +630,13 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns stats for active connections
|
// GetStats returns stats for active connections
|
||||||
func (conns *ActiveConnections) GetStats() []ConnectionStatus {
|
func (conns *ActiveConnections) GetStats() []*ConnectionStatus {
|
||||||
conns.RLock()
|
conns.RLock()
|
||||||
defer conns.RUnlock()
|
defer conns.RUnlock()
|
||||||
|
|
||||||
stats := make([]ConnectionStatus, 0, len(conns.connections))
|
stats := make([]*ConnectionStatus, 0, len(conns.connections))
|
||||||
for _, c := range conns.connections {
|
for _, c := range conns.connections {
|
||||||
stat := ConnectionStatus{
|
stat := &ConnectionStatus{
|
||||||
Username: c.GetUsername(),
|
Username: c.GetUsername(),
|
||||||
ConnectionID: c.GetID(),
|
ConnectionID: c.GetID(),
|
||||||
ClientVersion: c.GetClientVersion(),
|
ClientVersion: c.GetClientVersion(),
|
||||||
|
@ -675,14 +675,14 @@ type ConnectionStatus struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnectionDuration returns the connection duration as string
|
// GetConnectionDuration returns the connection duration as string
|
||||||
func (c ConnectionStatus) GetConnectionDuration() string {
|
func (c *ConnectionStatus) GetConnectionDuration() string {
|
||||||
elapsed := time.Since(utils.GetTimeFromMsecSinceEpoch(c.ConnectionTime))
|
elapsed := time.Since(utils.GetTimeFromMsecSinceEpoch(c.ConnectionTime))
|
||||||
return utils.GetDurationAsString(elapsed)
|
return utils.GetDurationAsString(elapsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnectionInfo returns connection info.
|
// GetConnectionInfo returns connection info.
|
||||||
// Protocol,Client Version and RemoteAddress are returned.
|
// Protocol,Client Version and RemoteAddress are returned.
|
||||||
func (c ConnectionStatus) GetConnectionInfo() string {
|
func (c *ConnectionStatus) GetConnectionInfo() string {
|
||||||
var result strings.Builder
|
var result strings.Builder
|
||||||
|
|
||||||
result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress))
|
result.WriteString(fmt.Sprintf("%v. Client: %#v From: %#v", c.Protocol, c.ClientVersion, c.RemoteAddress))
|
||||||
|
@ -702,7 +702,7 @@ func (c ConnectionStatus) GetConnectionInfo() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTransfersAsString returns the active transfers as string
|
// GetTransfersAsString returns the active transfers as string
|
||||||
func (c ConnectionStatus) GetTransfersAsString() string {
|
func (c *ConnectionStatus) GetTransfersAsString() string {
|
||||||
result := ""
|
result := ""
|
||||||
for _, t := range c.Transfers {
|
for _, t := range c.Transfers {
|
||||||
if result != "" {
|
if result != "" {
|
||||||
|
|
|
@ -37,10 +37,10 @@ type BaseConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseConnection returns a new BaseConnection
|
// NewBaseConnection returns a new BaseConnection
|
||||||
func NewBaseConnection(ID, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection {
|
func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection {
|
||||||
connID := ID
|
connID := id
|
||||||
if utils.IsStringInSlice(protocol, supportedProtocols) {
|
if utils.IsStringInSlice(protocol, supportedProtocols) {
|
||||||
connID = fmt.Sprintf("%v_%v", protocol, ID)
|
connID = fmt.Sprintf("%v_%v", protocol, id)
|
||||||
}
|
}
|
||||||
return &BaseConnection{
|
return &BaseConnection{
|
||||||
ID: connID,
|
ID: connID,
|
||||||
|
@ -272,12 +272,12 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo
|
||||||
if info.Mode()&os.ModeSymlink == 0 {
|
if info.Mode()&os.ModeSymlink == 0 {
|
||||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, -1, -size, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.User, -1, -size, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if actionErr != nil {
|
if actionErr != nil {
|
||||||
|
@ -577,12 +577,12 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er
|
||||||
sizeDiff := initialSize - size
|
sizeDiff := initialSize - size
|
||||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -sizeDiff, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -835,64 +835,64 @@ func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetP
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder vfs.VirtualFolder, initialSize,
|
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize,
|
||||||
filesSize int64, numFiles int) {
|
filesSize int64, numFiles int) {
|
||||||
if sourceFolder.MappedPath == dstFolder.MappedPath {
|
if sourceFolder.MappedPath == dstFolder.MappedPath {
|
||||||
// both files are inside the same virtual folder
|
// both files are inside the same virtual folder
|
||||||
if initialSize != -1 {
|
if initialSize != -1 {
|
||||||
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
|
||||||
if dstFolder.IsIncludedInUserQuota() {
|
if dstFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, -numFiles, -initialSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -numFiles, -initialSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// files are inside different virtual folders
|
// files are inside different virtual folders
|
||||||
dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
||||||
if sourceFolder.IsIncludedInUserQuota() {
|
if sourceFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
if initialSize == -1 {
|
if initialSize == -1 {
|
||||||
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
||||||
if dstFolder.IsIncludedInUserQuota() {
|
if dstFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// we cannot have a directory here, initialSize != -1 only for files
|
// we cannot have a directory here, initialSize != -1 only for files
|
||||||
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||||
if dstFolder.IsIncludedInUserQuota() {
|
if dstFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
||||||
// move between a virtual folder and the user home dir
|
// move between a virtual folder and the user home dir
|
||||||
dataprovider.UpdateVirtualFolderQuota(sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&sourceFolder.BaseVirtualFolder, -numFiles, -filesSize, false) //nolint:errcheck
|
||||||
if sourceFolder.IsIncludedInUserQuota() {
|
if sourceFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
if initialSize == -1 {
|
if initialSize == -1 {
|
||||||
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||||
} else {
|
} else {
|
||||||
// we cannot have a directory here, initialSize != -1 only for files
|
// we cannot have a directory here, initialSize != -1 only for files
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
|
||||||
// move between the user home dir and a virtual folder
|
// move between the user home dir and a virtual folder
|
||||||
dataprovider.UpdateUserQuota(c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
|
||||||
if initialSize == -1 {
|
if initialSize == -1 {
|
||||||
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, numFiles, filesSize, false) //nolint:errcheck
|
||||||
if dstFolder.IsIncludedInUserQuota() {
|
if dstFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, numFiles, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// we cannot have a directory here, initialSize != -1 only for files
|
// we cannot have a directory here, initialSize != -1 only for files
|
||||||
dataprovider.UpdateVirtualFolderQuota(dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||||
if dstFolder.IsIncludedInUserQuota() {
|
if dstFolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -909,7 +909,7 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
|
||||||
// both files are contained inside the user home dir
|
// both files are contained inside the user home dir
|
||||||
if initialSize != -1 {
|
if initialSize != -1 {
|
||||||
// we cannot have a directory here
|
// we cannot have a directory here
|
||||||
dataprovider.UpdateUserQuota(c.User, -1, -initialSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -932,13 +932,13 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if errSrc == nil && errDst == nil {
|
if errSrc == nil && errDst == nil {
|
||||||
c.updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder, initialSize, filesSize, numFiles)
|
c.updateQuotaMoveBetweenVFolders(&sourceFolder, &dstFolder, initialSize, filesSize, numFiles)
|
||||||
}
|
}
|
||||||
if errSrc == nil && errDst != nil {
|
if errSrc == nil && errDst != nil {
|
||||||
c.updateQuotaMoveFromVFolder(sourceFolder, initialSize, filesSize, numFiles)
|
c.updateQuotaMoveFromVFolder(&sourceFolder, initialSize, filesSize, numFiles)
|
||||||
}
|
}
|
||||||
if errSrc != nil && errDst == nil {
|
if errSrc != nil && errDst == nil {
|
||||||
c.updateQuotaMoveToVFolder(dstFolder, initialSize, filesSize, numFiles)
|
c.updateQuotaMoveToVFolder(&dstFolder, initialSize, filesSize, numFiles)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1054,7 +1054,7 @@ func TestHasSpace(t *testing.T) {
|
||||||
|
|
||||||
folder, err := dataprovider.GetFolderByName(folderName)
|
folder, err := dataprovider.GetFolderByName(folderName)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = dataprovider.UpdateVirtualFolderQuota(folder, 10, 1048576, true)
|
err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 1048576, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
quotaResult = c.HasSpace(true, false, "/vdir/file1")
|
quotaResult = c.HasSpace(true, false, "/vdir/file1")
|
||||||
assert.False(t, quotaResult.HasSpace)
|
assert.False(t, quotaResult.HasSpace)
|
||||||
|
@ -1105,14 +1105,14 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
folder2, err := dataprovider.GetFolderByName(folderName2)
|
folder2, err := dataprovider.GetFolderByName(folderName2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = dataprovider.UpdateVirtualFolderQuota(folder1, 1, 100, true)
|
err = dataprovider.UpdateVirtualFolderQuota(&folder1, 1, 100, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = dataprovider.UpdateVirtualFolderQuota(folder2, 2, 150, true)
|
err = dataprovider.UpdateVirtualFolderQuota(&folder2, 2, 150, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
fs, err := user.GetFilesystem("id")
|
fs, err := user.GetFilesystem("id")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
c := NewBaseConnection("", ProtocolSFTP, user, fs)
|
c := NewBaseConnection("", ProtocolSFTP, user, fs)
|
||||||
c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[0], user.VirtualFolders[1], -1, 100, 1)
|
c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[0], &user.VirtualFolders[1], -1, 100, 1)
|
||||||
folder1, err = dataprovider.GetFolderByName(folderName1)
|
folder1, err = dataprovider.GetFolderByName(folderName1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, folder1.UsedQuotaFiles)
|
assert.Equal(t, 0, folder1.UsedQuotaFiles)
|
||||||
|
@ -1122,7 +1122,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
|
||||||
assert.Equal(t, 3, folder2.UsedQuotaFiles)
|
assert.Equal(t, 3, folder2.UsedQuotaFiles)
|
||||||
assert.Equal(t, int64(250), folder2.UsedQuotaSize)
|
assert.Equal(t, int64(250), folder2.UsedQuotaSize)
|
||||||
|
|
||||||
c.updateQuotaMoveBetweenVFolders(user.VirtualFolders[1], user.VirtualFolders[0], 10, 100, 1)
|
c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[1], &user.VirtualFolders[0], 10, 100, 1)
|
||||||
folder1, err = dataprovider.GetFolderByName(folderName1)
|
folder1, err = dataprovider.GetFolderByName(folderName1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, folder1.UsedQuotaFiles)
|
assert.Equal(t, 0, folder1.UsedQuotaFiles)
|
||||||
|
@ -1132,9 +1132,9 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
|
||||||
assert.Equal(t, 2, folder2.UsedQuotaFiles)
|
assert.Equal(t, 2, folder2.UsedQuotaFiles)
|
||||||
assert.Equal(t, int64(150), folder2.UsedQuotaSize)
|
assert.Equal(t, int64(150), folder2.UsedQuotaSize)
|
||||||
|
|
||||||
err = dataprovider.UpdateUserQuota(user, 1, 100, true)
|
err = dataprovider.UpdateUserQuota(&user, 1, 100, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
c.updateQuotaMoveFromVFolder(user.VirtualFolders[1], -1, 50, 1)
|
c.updateQuotaMoveFromVFolder(&user.VirtualFolders[1], -1, 50, 1)
|
||||||
folder2, err = dataprovider.GetFolderByName(folderName2)
|
folder2, err = dataprovider.GetFolderByName(folderName2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 1, folder2.UsedQuotaFiles)
|
assert.Equal(t, 1, folder2.UsedQuotaFiles)
|
||||||
|
@ -1144,7 +1144,7 @@ func TestUpdateQuotaMoveVFolders(t *testing.T) {
|
||||||
assert.Equal(t, 1, user.UsedQuotaFiles)
|
assert.Equal(t, 1, user.UsedQuotaFiles)
|
||||||
assert.Equal(t, int64(100), user.UsedQuotaSize)
|
assert.Equal(t, int64(100), user.UsedQuotaSize)
|
||||||
|
|
||||||
c.updateQuotaMoveToVFolder(user.VirtualFolders[1], -1, 100, 1)
|
c.updateQuotaMoveToVFolder(&user.VirtualFolders[1], -1, 100, 1)
|
||||||
folder2, err = dataprovider.GetFolderByName(folderName2)
|
folder2, err = dataprovider.GetFolderByName(folderName2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 2, folder2.UsedQuotaFiles)
|
assert.Equal(t, 2, folder2.UsedQuotaFiles)
|
||||||
|
|
|
@ -266,13 +266,13 @@ func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
|
||||||
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) {
|
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) {
|
||||||
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
|
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
|
||||||
sizeDiff, false)
|
sizeDiff, false)
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
usersBucket = []byte("users")
|
usersBucket = []byte("users")
|
||||||
//usersIDIdxBucket = []byte("users_id_idx")
|
|
||||||
foldersBucket = []byte("folders")
|
foldersBucket = []byte("folders")
|
||||||
adminsBucket = []byte("admins")
|
adminsBucket = []byte("admins")
|
||||||
dbVersionBucket = []byte("db_version")
|
dbVersionBucket = []byte("db_version")
|
||||||
|
@ -113,7 +112,7 @@ func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol stri
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return checkUserAndPass(user, password, ip, protocol)
|
return checkUserAndPass(&user, password, ip, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
|
func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
|
||||||
|
@ -136,7 +135,7 @@ func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (Us
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, "", err
|
return user, "", err
|
||||||
}
|
}
|
||||||
return checkUserAndPubKey(user, pubKey)
|
return checkUserAndPubKey(&user, pubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BoltProvider) updateLastLogin(username string) error {
|
func (p *BoltProvider) updateLastLogin(username string) error {
|
||||||
|
|
|
@ -625,14 +625,14 @@ func CheckUserAndPass(username, password, ip, protocol string) (User, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return checkUserAndPass(user, password, ip, protocol)
|
return checkUserAndPass(&user, password, ip, protocol)
|
||||||
}
|
}
|
||||||
if config.PreLoginHook != "" {
|
if config.PreLoginHook != "" {
|
||||||
user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol)
|
user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return checkUserAndPass(user, password, ip, protocol)
|
return checkUserAndPass(&user, password, ip, protocol)
|
||||||
}
|
}
|
||||||
return provider.validateUserAndPass(username, password, ip, protocol)
|
return provider.validateUserAndPass(username, password, ip, protocol)
|
||||||
}
|
}
|
||||||
|
@ -644,14 +644,14 @@ func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string) (Us
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, "", err
|
return user, "", err
|
||||||
}
|
}
|
||||||
return checkUserAndPubKey(user, pubKey)
|
return checkUserAndPubKey(&user, pubKey)
|
||||||
}
|
}
|
||||||
if config.PreLoginHook != "" {
|
if config.PreLoginHook != "" {
|
||||||
user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol)
|
user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, "", err
|
return user, "", err
|
||||||
}
|
}
|
||||||
return checkUserAndPubKey(user, pubKey)
|
return checkUserAndPubKey(&user, pubKey)
|
||||||
}
|
}
|
||||||
return provider.validateUserAndPubKey(username, pubKey)
|
return provider.validateUserAndPubKey(username, pubKey)
|
||||||
}
|
}
|
||||||
|
@ -671,11 +671,11 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return doKeyboardInteractiveAuth(user, authHook, client, ip, protocol)
|
return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLastLogin updates the last login fields for the given SFTP user
|
// UpdateLastLogin updates the last login fields for the given SFTP user
|
||||||
func UpdateLastLogin(user User) error {
|
func UpdateLastLogin(user *User) error {
|
||||||
lastLogin := utils.GetTimeFromMsecSinceEpoch(user.LastLogin)
|
lastLogin := utils.GetTimeFromMsecSinceEpoch(user.LastLogin)
|
||||||
diff := -time.Until(lastLogin)
|
diff := -time.Until(lastLogin)
|
||||||
if diff < 0 || diff > lastLoginMinDelay {
|
if diff < 0 || diff > lastLoginMinDelay {
|
||||||
|
@ -690,7 +690,7 @@ func UpdateLastLogin(user User) error {
|
||||||
|
|
||||||
// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd.
|
// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd.
|
||||||
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
|
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
|
||||||
func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error {
|
func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error {
|
||||||
if config.TrackQuota == 0 {
|
if config.TrackQuota == 0 {
|
||||||
return &MethodDisabledError{err: trackQuotaDisabledError}
|
return &MethodDisabledError{err: trackQuotaDisabledError}
|
||||||
} else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() {
|
} else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() {
|
||||||
|
@ -704,7 +704,7 @@ func UpdateUserQuota(user User, filesAdd int, sizeAdd int64, reset bool) error {
|
||||||
|
|
||||||
// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
|
// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd.
|
||||||
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
|
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
|
||||||
func UpdateVirtualFolderQuota(vfolder vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
|
func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error {
|
||||||
if config.TrackQuota == 0 {
|
if config.TrackQuota == 0 {
|
||||||
return &MethodDisabledError{err: trackQuotaDisabledError}
|
return &MethodDisabledError{err: trackQuotaDisabledError}
|
||||||
}
|
}
|
||||||
|
@ -1482,53 +1482,53 @@ func isPasswordOK(user *User, password string) (bool, error) {
|
||||||
return match, err
|
return match, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUserAndPass(user User, password, ip, protocol string) (User, error) {
|
func checkUserAndPass(user *User, password, ip, protocol string) (User, error) {
|
||||||
err := checkLoginConditions(&user)
|
err := checkLoginConditions(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return *user, err
|
||||||
}
|
}
|
||||||
if user.Password == "" {
|
if user.Password == "" {
|
||||||
return user, errors.New("Credentials cannot be null or empty")
|
return *user, errors.New("Credentials cannot be null or empty")
|
||||||
}
|
}
|
||||||
hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol)
|
hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
providerLog(logger.LevelDebug, "error executing check password hook: %v", err)
|
providerLog(logger.LevelDebug, "error executing check password hook: %v", err)
|
||||||
return user, errors.New("Unable to check credentials")
|
return *user, errors.New("Unable to check credentials")
|
||||||
}
|
}
|
||||||
switch hookResponse.Status {
|
switch hookResponse.Status {
|
||||||
case -1:
|
case -1:
|
||||||
// no hook configured
|
// no hook configured
|
||||||
case 1:
|
case 1:
|
||||||
providerLog(logger.LevelDebug, "password accepted by check password hook")
|
providerLog(logger.LevelDebug, "password accepted by check password hook")
|
||||||
return user, nil
|
return *user, nil
|
||||||
case 2:
|
case 2:
|
||||||
providerLog(logger.LevelDebug, "partial success from check password hook")
|
providerLog(logger.LevelDebug, "partial success from check password hook")
|
||||||
password = hookResponse.ToVerify
|
password = hookResponse.ToVerify
|
||||||
default:
|
default:
|
||||||
providerLog(logger.LevelDebug, "password rejected by check password hook, status: %v", hookResponse.Status)
|
providerLog(logger.LevelDebug, "password rejected by check password hook, status: %v", hookResponse.Status)
|
||||||
return user, ErrInvalidCredentials
|
return *user, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
match, err := isPasswordOK(&user, password)
|
match, err := isPasswordOK(user, password)
|
||||||
if !match {
|
if !match {
|
||||||
err = ErrInvalidCredentials
|
err = ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
return user, err
|
return *user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) {
|
func checkUserAndPubKey(user *User, pubKey []byte) (User, string, error) {
|
||||||
err := checkLoginConditions(&user)
|
err := checkLoginConditions(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, "", err
|
return *user, "", err
|
||||||
}
|
}
|
||||||
if len(user.PublicKeys) == 0 {
|
if len(user.PublicKeys) == 0 {
|
||||||
return user, "", ErrInvalidCredentials
|
return *user, "", ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
for i, k := range user.PublicKeys {
|
for i, k := range user.PublicKeys {
|
||||||
storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
|
storedPubKey, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
|
providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
|
||||||
return user, "", err
|
return *user, "", err
|
||||||
}
|
}
|
||||||
if bytes.Equal(storedPubKey.Marshal(), pubKey) {
|
if bytes.Equal(storedPubKey.Marshal(), pubKey) {
|
||||||
certInfo := ""
|
certInfo := ""
|
||||||
|
@ -1537,10 +1537,10 @@ func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) {
|
||||||
certInfo = fmt.Sprintf(" %v ID: %v Serial: %v CA: %v", cert.Type(), cert.KeyId, cert.Serial,
|
certInfo = fmt.Sprintf(" %v ID: %v Serial: %v CA: %v", cert.Type(), cert.KeyId, cert.Serial,
|
||||||
ssh.FingerprintSHA256(cert.SignatureKey))
|
ssh.FingerprintSHA256(cert.SignatureKey))
|
||||||
}
|
}
|
||||||
return user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil
|
return *user, fmt.Sprintf("%v:%v%v", ssh.FingerprintSHA256(storedPubKey), comment, certInfo), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return user, "", ErrInvalidCredentials
|
return *user, "", ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
func compareUnixPasswordAndHash(user *User, password string) (bool, error) {
|
func compareUnixPasswordAndHash(user *User, password string) (bool, error) {
|
||||||
|
@ -1712,7 +1712,7 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
|
func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
|
||||||
authResult := 0
|
authResult := 0
|
||||||
var url *url.URL
|
var url *url.URL
|
||||||
url, err := url.Parse(authHook)
|
url, err := url.Parse(authHook)
|
||||||
|
@ -1754,7 +1754,7 @@ func executeKeyboardInteractiveHTTPHook(user User, authHook string, client ssh.K
|
||||||
}
|
}
|
||||||
|
|
||||||
func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
|
func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
|
||||||
user User, ip, protocol string) ([]string, error) {
|
user *User, ip, protocol string) ([]string, error) {
|
||||||
questions := response.Questions
|
questions := response.Questions
|
||||||
answers, err := client(user.Username, response.Instruction, questions, response.Echos)
|
answers, err := client(user.Username, response.Instruction, questions, response.Echos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1779,7 +1779,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
|
func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response keyboardAuthHookResponse,
|
||||||
user User, stdin io.WriteCloser, ip, protocol string) error {
|
user *User, stdin io.WriteCloser, ip, protocol string) error {
|
||||||
answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
|
answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1798,7 +1798,7 @@ func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
|
func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) {
|
||||||
authResult := 0
|
authResult := 0
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -1856,7 +1856,7 @@ func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.Ke
|
||||||
return authResult, err
|
return authResult, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) {
|
func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) {
|
||||||
var authResult int
|
var authResult int
|
||||||
var err error
|
var err error
|
||||||
if strings.HasPrefix(authHook, "http") {
|
if strings.HasPrefix(authHook, "http") {
|
||||||
|
@ -1865,16 +1865,16 @@ func doKeyboardInteractiveAuth(user User, authHook string, client ssh.KeyboardIn
|
||||||
authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol)
|
authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return *user, err
|
||||||
}
|
}
|
||||||
if authResult != 1 {
|
if authResult != 1 {
|
||||||
return user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult)
|
return *user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult)
|
||||||
}
|
}
|
||||||
err = checkLoginConditions(&user)
|
err = checkLoginConditions(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return *user, err
|
||||||
}
|
}
|
||||||
return user, nil
|
return *user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isCheckPasswordHookDefined(protocol string) bool {
|
func isCheckPasswordHookDefined(protocol string) bool {
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol st
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return checkUserAndPass(user, password, ip, protocol)
|
return checkUserAndPass(&user, password, ip, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
|
func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
|
||||||
|
@ -112,7 +112,7 @@ func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, "", err
|
return user, "", err
|
||||||
}
|
}
|
||||||
return checkUserAndPubKey(user, pubKey)
|
return checkUserAndPubKey(&user, pubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
|
func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
|
||||||
|
|
|
@ -225,23 +225,23 @@ func (p *MySQLProvider) initializeDatabase() error {
|
||||||
return ErrNoInitRequired
|
return ErrNoInitRequired
|
||||||
}
|
}
|
||||||
sqlUsers := strings.Replace(mysqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
sqlUsers := strings.Replace(mysqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
||||||
tx, err := p.dbHandle.Begin()
|
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tx, err := p.dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(sqlUsers)
|
_, err = tx.Exec(sqlUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(mysqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(mysqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
|
|
@ -229,23 +229,23 @@ func (p *PGSQLProvider) initializeDatabase() error {
|
||||||
return ErrNoInitRequired
|
return ErrNoInitRequired
|
||||||
}
|
}
|
||||||
sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
||||||
tx, err := p.dbHandle.Begin()
|
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tx, err := p.dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(sqlUsers)
|
_, err = tx.Exec(sqlUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
|
|
@ -224,7 +224,7 @@ func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHan
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
return checkUserAndPass(user, password, ip, protocol)
|
return checkUserAndPass(&user, password, ip, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
|
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
|
||||||
|
@ -237,7 +237,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sq
|
||||||
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
|
||||||
return user, "", err
|
return user, "", err
|
||||||
}
|
}
|
||||||
return checkUserAndPubKey(user, pubKey)
|
return checkUserAndPubKey(&user, pubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
|
func sqlCommonCheckAvailability(dbHandle *sql.DB) error {
|
||||||
|
@ -313,6 +313,7 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -321,40 +322,33 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
||||||
stmt, err := tx.PrepareContext(ctx, q)
|
stmt, err := tx.PrepareContext(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
permissions, err := user.GetPermissionsAsJSON()
|
permissions, err := user.GetPermissionsAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
filters, err := user.GetFiltersAsJSON()
|
filters, err := user.GetFiltersAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fsConfig, err := user.GetFsConfigAsJSON()
|
fsConfig, err := user.GetFsConfigAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
|
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
|
||||||
string(fsConfig), user.AdditionalInfo)
|
string(fsConfig), user.AdditionalInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = generateVirtualFoldersMapping(ctx, user, tx)
|
err = generateVirtualFoldersMapping(ctx, user, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
@ -367,6 +361,7 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -375,40 +370,33 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||||
stmt, err := tx.PrepareContext(ctx, q)
|
stmt, err := tx.PrepareContext(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
permissions, err := user.GetPermissionsAsJSON()
|
permissions, err := user.GetPermissionsAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
filters, err := user.GetFiltersAsJSON()
|
filters, err := user.GetFiltersAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fsConfig, err := user.GetFsConfigAsJSON()
|
fsConfig, err := user.GetFsConfigAsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
|
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
|
||||||
string(filters), string(fsConfig), user.AdditionalInfo, user.ID)
|
string(filters), string(fsConfig), user.AdditionalInfo, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = generateVirtualFoldersMapping(ctx, user, tx)
|
err = generateVirtualFoldersMapping(ctx, user, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
@ -979,13 +967,6 @@ func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int6
|
||||||
return usedFiles, usedSize, err
|
return usedFiles, usedSize, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqlCommonRollbackTransaction(tx *sql.Tx) {
|
|
||||||
err := tx.Rollback()
|
|
||||||
if err != nil {
|
|
||||||
providerLog(logger.LevelWarn, "error rolling back transaction: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
|
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
|
||||||
var result schemaVersion
|
var result schemaVersion
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||||
|
@ -1030,13 +1011,11 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersi
|
||||||
}
|
}
|
||||||
_, err = tx.ExecContext(ctx, q)
|
_, err = tx.ExecContext(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
|
err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
@ -1130,6 +1109,7 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
|
||||||
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
|
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1140,25 +1120,14 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
|
||||||
}
|
}
|
||||||
_, err = tx.ExecContext(ctx, q)
|
_, err = tx.ExecContext(ctx, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*_, err = sqlCommonRestoreCompatVirtualFolders(ctx, users, tx)
|
|
||||||
if err != nil {
|
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
|
||||||
}*/
|
|
||||||
err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4)
|
err = sqlCommonUpdateDatabaseVersion(ctx, tx, 4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
/*if err == nil {
|
|
||||||
go updateVFoldersQuotaAfterRestore(foldersToScan)
|
|
||||||
}
|
|
||||||
return err*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:dupl
|
//nolint:dupl
|
||||||
|
|
|
@ -262,23 +262,23 @@ func (p *SQLiteProvider) initializeDatabase() error {
|
||||||
return ErrNoInitRequired
|
return ErrNoInitRequired
|
||||||
}
|
}
|
||||||
sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1)
|
||||||
tx, err := p.dbHandle.Begin()
|
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tx, err := p.dbHandle.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(sqlUsers)
|
_, err = tx.Exec(sqlUsers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(sqliteSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(sqliteSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqlCommonRollbackTransaction(tx)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
|
|
@ -470,12 +470,12 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file
|
||||||
if vfs.IsLocalOrSFTPFs(c.Fs) {
|
if vfs.IsLocalOrSFTPFs(c.Fs) {
|
||||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
initialSize = fileSize
|
initialSize = fileSize
|
||||||
|
|
|
@ -152,7 +152,7 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
|
||||||
connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID())
|
connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID())
|
||||||
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v",
|
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v",
|
||||||
user.ID, user.Username, user.HomeDir, ipAddr)
|
user.ID, user.Username, user.HomeDir, ipAddr)
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(&user) //nolint:errcheck
|
||||||
return connection, nil
|
return connection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
|
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
|
||||||
err = dataprovider.UpdateUserQuota(user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset)
|
err = dataprovider.UpdateUserQuota(&user, u.UsedQuotaFiles, u.UsedQuotaSize, mode == quotaUpdateModeReset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||||
} else {
|
} else {
|
||||||
|
@ -94,7 +94,7 @@ func updateVFolderQuotaUsage(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
|
||||||
err = dataprovider.UpdateVirtualFolderQuota(folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset)
|
err = dataprovider.UpdateVirtualFolderQuota(&folder, f.UsedQuotaFiles, f.UsedQuotaSize, mode == quotaUpdateModeReset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||||
} else {
|
} else {
|
||||||
|
@ -165,7 +165,7 @@ func doQuotaScan(user dataprovider.User) error {
|
||||||
logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err)
|
logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = dataprovider.UpdateUserQuota(user, numFiles, size, true)
|
err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
|
||||||
logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err)
|
logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -178,7 +178,7 @@ func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error {
|
||||||
logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err)
|
logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = dataprovider.UpdateVirtualFolderQuota(folder, numFiles, size, true)
|
err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
|
||||||
logger.Debug(logSender, "", "virtual folder %#v scanned, error: %v", folder.Name, err)
|
logger.Debug(logSender, "", "virtual folder %#v scanned, error: %v", folder.Name, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ type foldersPage struct {
|
||||||
|
|
||||||
type connectionsPage struct {
|
type connectionsPage struct {
|
||||||
basePage
|
basePage
|
||||||
Connections []common.ConnectionStatus
|
Connections []*common.ConnectionStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
type statusPage struct {
|
type statusPage struct {
|
||||||
|
|
|
@ -69,7 +69,7 @@ func (l *LeveledLogger) addKeysAndValues(ev *zerolog.Event, keysAndValues ...int
|
||||||
extra := keysAndValues[kvLen-1]
|
extra := keysAndValues[kvLen-1]
|
||||||
keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra)
|
keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra)
|
||||||
}
|
}
|
||||||
for i := 0; i < len(keysAndValues); i = i + 2 {
|
for i := 0; i < len(keysAndValues); i += 2 {
|
||||||
key, val := keysAndValues[i], keysAndValues[i+1]
|
key, val := keysAndValues[i], keysAndValues[i+1]
|
||||||
if keyStr, ok := key.(string); ok {
|
if keyStr, ok := key.(string); ok {
|
||||||
ev.Str(keyStr, fmt.Sprintf("%v", val))
|
ev.Str(keyStr, fmt.Sprintf("%v", val))
|
||||||
|
|
|
@ -412,12 +412,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
|
||||||
if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate {
|
if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate {
|
||||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
initialSize = fileSize
|
initialSize = fileSize
|
||||||
|
@ -460,7 +460,7 @@ func (c *Connection) getStatVFSFromQuotaResult(name string, quotaResult vfs.Quot
|
||||||
|
|
||||||
bsize := uint64(4096)
|
bsize := uint64(4096)
|
||||||
for bsize > uint64(quotaResult.QuotaSize) {
|
for bsize > uint64(quotaResult.QuotaSize) {
|
||||||
bsize = bsize / 4
|
bsize /= 4
|
||||||
}
|
}
|
||||||
blocks := uint64(quotaResult.QuotaSize) / bsize
|
blocks := uint64(quotaResult.QuotaSize) / bsize
|
||||||
bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize
|
bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize
|
||||||
|
|
|
@ -365,7 +365,7 @@ func TestUploadFiles(t *testing.T) {
|
||||||
func TestWithInvalidHome(t *testing.T) {
|
func TestWithInvalidHome(t *testing.T) {
|
||||||
u := dataprovider.User{}
|
u := dataprovider.User{}
|
||||||
u.HomeDir = "home_rel_path" //nolint:goconst
|
u.HomeDir = "home_rel_path" //nolint:goconst
|
||||||
_, err := loginUser(u, dataprovider.LoginMethodPassword, "", nil)
|
_, err := loginUser(&u, dataprovider.LoginMethodPassword, "", nil)
|
||||||
assert.Error(t, err, "login a user with an invalid home_dir must fail")
|
assert.Error(t, err, "login a user with an invalid home_dir must fail")
|
||||||
|
|
||||||
u.HomeDir = os.TempDir()
|
u.HomeDir = os.TempDir()
|
||||||
|
@ -1890,7 +1890,7 @@ func TestRecursiveCopyErrors(t *testing.T) {
|
||||||
func TestSFTPSubSystem(t *testing.T) {
|
func TestSFTPSubSystem(t *testing.T) {
|
||||||
permissions := make(map[string][]string)
|
permissions := make(map[string][]string)
|
||||||
permissions["/"] = []string{dataprovider.PermAny}
|
permissions["/"] = []string{dataprovider.PermAny}
|
||||||
user := dataprovider.User{
|
user := &dataprovider.User{
|
||||||
Permissions: permissions,
|
Permissions: permissions,
|
||||||
HomeDir: os.TempDir(),
|
HomeDir: os.TempDir(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -213,12 +213,12 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead
|
||||||
if vfs.IsLocalOrSFTPFs(c.connection.Fs) {
|
if vfs.IsLocalOrSFTPFs(c.connection.Fs) {
|
||||||
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.connection.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
initialSize = fileSize
|
initialSize = fileSize
|
||||||
|
|
|
@ -408,7 +408,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
||||||
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||||
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
||||||
user.ID, loginType, user.Username, user.HomeDir, ipAddr)
|
user.ID, loginType, user.Username, user.HomeDir, ipAddr)
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(&user) //nolint:errcheck
|
||||||
|
|
||||||
sshConnection := common.NewSSHConnection(connectionID, conn)
|
sshConnection := common.NewSSHConnection(connectionID, conn)
|
||||||
common.Connections.AddSSHConnection(sshConnection)
|
common.Connections.AddSSHConnection(sshConnection)
|
||||||
|
@ -557,7 +557,7 @@ func checkRootPath(user *dataprovider.User, connectionID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loginUser(user dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
|
func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
|
||||||
connectionID := ""
|
connectionID := ""
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
connectionID = hex.EncodeToString(conn.SessionID())
|
connectionID = hex.EncodeToString(conn.SessionID())
|
||||||
|
@ -817,7 +817,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
|
||||||
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
|
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
|
||||||
return certPerm, ssh.ErrPartialSuccess
|
return certPerm, ssh.ErrPartialSuccess
|
||||||
}
|
}
|
||||||
sshPerm, err = loginUser(user, method, keyID, conn)
|
sshPerm, err = loginUser(&user, method, keyID, conn)
|
||||||
if err == nil && certPerm != nil {
|
if err == nil && certPerm != nil {
|
||||||
// if we have a SSH user cert we need to merge certificate permissions with our ones
|
// if we have a SSH user cert we need to merge certificate permissions with our ones
|
||||||
// we only set Extensions, so CriticalOptions are always the ones from the certificate
|
// we only set Extensions, so CriticalOptions are always the ones from the certificate
|
||||||
|
@ -845,7 +845,7 @@ func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass
|
||||||
}
|
}
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||||
if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil {
|
if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil {
|
||||||
sshPerm, err = loginUser(user, method, "", conn)
|
sshPerm, err = loginUser(&user, method, "", conn)
|
||||||
}
|
}
|
||||||
user.Username = conn.User()
|
user.Username = conn.User()
|
||||||
updateLoginMetrics(&user, ipAddr, method, err)
|
updateLoginMetrics(&user, ipAddr, method, err)
|
||||||
|
@ -864,7 +864,7 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||||
if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client,
|
if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client,
|
||||||
ipAddr, common.ProtocolSSH); err == nil {
|
ipAddr, common.ProtocolSSH); err == nil {
|
||||||
sshPerm, err = loginUser(user, method, "", conn)
|
sshPerm, err = loginUser(&user, method, "", conn)
|
||||||
}
|
}
|
||||||
user.Username = conn.User()
|
user.Username = conn.User()
|
||||||
updateLoginMetrics(&user, ipAddr, method, err)
|
updateLoginMetrics(&user, ipAddr, method, err)
|
||||||
|
|
|
@ -6490,7 +6490,7 @@ func TestStatVFSCloudBackend(t *testing.T) {
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
err = dataprovider.UpdateUserQuota(user, 100, 8192, true)
|
err = dataprovider.UpdateUserQuota(&user, 100, 8192, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
stat, err := client.StatVFS("/")
|
stat, err := client.StatVFS("/")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -248,12 +248,12 @@ func (c *sshCommand) handeSFTPGoRemove() error {
|
||||||
func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
|
func (c *sshCommand) updateQuota(sshDestPath string, filesNum int, filesSize int64) {
|
||||||
vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
|
vfolder, err := c.connection.User.GetVirtualFolderForPath(sshDestPath)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, filesNum, filesSize, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.connection.User, filesNum, filesSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ func newSubsystemChannel(reader io.Reader, writer io.Writer) *subsystemChannel {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeSubSystemConnection handles a connection as SSH subsystem
|
// ServeSubSystemConnection handles a connection as SSH subsystem
|
||||||
func ServeSubSystemConnection(user dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error {
|
func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error {
|
||||||
fs, err := user.GetFilesystem(connectionID)
|
fs, err := user.GetFilesystem(connectionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -44,7 +44,7 @@ func ServeSubSystemConnection(user dataprovider.User, connectionID string, reade
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
||||||
|
|
||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, user, fs),
|
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, *user, fs),
|
||||||
ClientVersion: "",
|
ClientVersion: "",
|
||||||
RemoteAddr: &net.IPAddr{},
|
RemoteAddr: &net.IPAddr{},
|
||||||
channel: newSubsystemChannel(reader, writer),
|
channel: newSubsystemChannel(reader, writer),
|
||||||
|
|
|
@ -253,7 +253,7 @@ func (fs *AzureBlobFs) Create(name string, flag int) (File, *PipeWriter, func(),
|
||||||
// if we shutdown Azurite while uploading it hangs, so we use our own wrapper for
|
// if we shutdown Azurite while uploading it hangs, so we use our own wrapper for
|
||||||
// the low level functions
|
// the low level functions
|
||||||
_, err := azblob.UploadStreamToBlockBlob(ctx, r, blobBlockURL, uploadOptions)*/
|
_, err := azblob.UploadStreamToBlockBlob(ctx, r, blobBlockURL, uploadOptions)*/
|
||||||
err := fs.handleMultipartUpload(ctx, r, blobBlockURL, headers)
|
err := fs.handleMultipartUpload(ctx, r, &blobBlockURL, &headers)
|
||||||
r.CloseWithError(err) //nolint:errcheck
|
r.CloseWithError(err) //nolint:errcheck
|
||||||
p.Done(err)
|
p.Done(err)
|
||||||
fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, r.GetReadedBytes(), err)
|
fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, r.GetReadedBytes(), err)
|
||||||
|
@ -438,7 +438,8 @@ func (fs *AzureBlobFs) ReadDir(dirname string) ([]os.FileInfo, error) {
|
||||||
result = append(result, NewFileInfo(name, true, 0, time.Now(), false))
|
result = append(result, NewFileInfo(name, true, 0, time.Now(), false))
|
||||||
prefixes[strings.TrimSuffix(name, "/")] = true
|
prefixes[strings.TrimSuffix(name, "/")] = true
|
||||||
}
|
}
|
||||||
for _, blobInfo := range listBlob.Segment.BlobItems {
|
for idx := range listBlob.Segment.BlobItems {
|
||||||
|
blobInfo := &listBlob.Segment.BlobItems[idx]
|
||||||
name := strings.TrimPrefix(blobInfo.Name, prefix)
|
name := strings.TrimPrefix(blobInfo.Name, prefix)
|
||||||
size := int64(0)
|
size := int64(0)
|
||||||
if blobInfo.Properties.ContentLength != nil {
|
if blobInfo.Properties.ContentLength != nil {
|
||||||
|
@ -556,7 +557,8 @@ func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) {
|
||||||
return numFiles, size, err
|
return numFiles, size, err
|
||||||
}
|
}
|
||||||
marker = listBlob.NextMarker
|
marker = listBlob.NextMarker
|
||||||
for _, blobInfo := range listBlob.Segment.BlobItems {
|
for idx := range listBlob.Segment.BlobItems {
|
||||||
|
blobInfo := &listBlob.Segment.BlobItems[idx]
|
||||||
isDir := false
|
isDir := false
|
||||||
if blobInfo.Properties.ContentType != nil {
|
if blobInfo.Properties.ContentType != nil {
|
||||||
isDir = (*blobInfo.Properties.ContentType == dirMimeType)
|
isDir = (*blobInfo.Properties.ContentType == dirMimeType)
|
||||||
|
@ -637,7 +639,8 @@ func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
marker = listBlob.NextMarker
|
marker = listBlob.NextMarker
|
||||||
for _, blobInfo := range listBlob.Segment.BlobItems {
|
for idx := range listBlob.Segment.BlobItems {
|
||||||
|
blobInfo := &listBlob.Segment.BlobItems[idx]
|
||||||
isDir := false
|
isDir := false
|
||||||
if blobInfo.Properties.ContentType != nil {
|
if blobInfo.Properties.ContentType != nil {
|
||||||
isDir = (*blobInfo.Properties.ContentType == dirMimeType)
|
isDir = (*blobInfo.Properties.ContentType == dirMimeType)
|
||||||
|
@ -776,8 +779,8 @@ func (fs *AzureBlobFs) hasContents(name string) (bool, error) {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL azblob.BlockBlobURL,
|
func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, blockBlobURL *azblob.BlockBlobURL,
|
||||||
httpHeaders azblob.BlobHTTPHeaders) error {
|
httpHeaders *azblob.BlobHTTPHeaders) error {
|
||||||
partSize := fs.config.UploadPartSize
|
partSize := fs.config.UploadPartSize
|
||||||
guard := make(chan struct{}, fs.config.UploadConcurrency)
|
guard := make(chan struct{}, fs.config.UploadConcurrency)
|
||||||
blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
|
blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
|
||||||
|
@ -852,7 +855,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
|
||||||
return poolError
|
return poolError
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := blockBlobURL.CommitBlockList(ctx, blocks, httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{},
|
_, err := blockBlobURL.CommitBlockList(ctx, blocks, *httpHeaders, azblob.Metadata{}, azblob.BlobAccessConditions{},
|
||||||
azblob.AccessTierType(fs.config.AccessTier), nil, azblob.ClientProvidedKeyOptions{})
|
azblob.AccessTierType(fs.config.AccessTier), nil, azblob.ClientProvidedKeyOptions{})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -264,12 +264,12 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f
|
||||||
if vfs.IsLocalOrSFTPFs(c.Fs) {
|
if vfs.IsLocalOrSFTPFs(c.Fs) {
|
||||||
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
|
||||||
if vfolder.IsIncludedInUserQuota() {
|
if vfolder.IsIncludedInUserQuota() {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dataprovider.UpdateUserQuota(c.User, 0, -fileSize, false) //nolint:errcheck
|
dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
initialSize = fileSize
|
initialSize = fileSize
|
||||||
|
|
|
@ -181,7 +181,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
common.Connections.Add(connection)
|
common.Connections.Add(connection)
|
||||||
defer common.Connections.Remove(connection.GetID())
|
defer common.Connections.Remove(connection.GetID())
|
||||||
|
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(&user) //nolint:errcheck
|
||||||
|
|
||||||
if s.checkRequestMethod(ctx, r, connection) {
|
if s.checkRequestMethod(ctx, r, connection) {
|
||||||
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
|
w.Header().Set("Content-Type", "text/xml; charset=utf-8")
|
||||||
|
|
Loading…
Reference in a new issue