diff --git a/dataprovider/user.go b/dataprovider/user.go index 798d5de8..1b9dc848 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -293,24 +293,6 @@ func (u *User) HasOverlappedMappedPaths() bool { return false } -// GetRemaingQuotaSize returns the available quota size for the given SFTP path -func (u *User) GetRemaingQuotaSize(sftpPath string) int64 { - vfolder, err := u.GetVirtualFolderForPath(sftpPath) - if err == nil { - if vfolder.IsIncludedInUserQuota() && u.QuotaSize > 0 { - return u.QuotaSize - u.UsedQuotaSize - } - if vfolder.QuotaSize > 0 { - return vfolder.QuotaSize - vfolder.UsedQuotaSize - } - } else { - if u.QuotaSize > 0 { - return u.QuotaSize - u.UsedQuotaSize - } - } - return 0 -} - // HasPerm returns true if the user has the given permission or any permission func (u *User) HasPerm(permission, path string) bool { perms := u.GetPermissionsForPath(path) diff --git a/sftpd/handler.go b/sftpd/handler.go index fcefe7eb..41d9286c 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -498,7 +498,8 @@ func (c Connection) handleSFTPRemove(filePath string, request *sftp.Request) err } func (c Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPath string) (io.WriterAt, error) { - if !c.hasSpace(true, requestPath) { + quotaResult := c.hasSpace(true, requestPath) + if !quotaResult.HasSpace { c.Log(logger.LevelInfo, logSender, "denying file write due to quota limits") return nil, sftp.ErrSSHFxFailure } @@ -539,7 +540,8 @@ func (c Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPat func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, resolvedPath, filePath string, fileSize int64, requestPath string) (io.WriterAt, error) { var err error - if !c.hasSpace(false, requestPath) { + quotaResult := c.hasSpace(false, requestPath) + if !quotaResult.HasSpace { c.Log(logger.LevelInfo, logSender, "denying file write due to quota limits") return nil, sftp.ErrSSHFxFailure } @@ -636,10 +638,12 @@ func (c Connection) hasSpaceForRename(request *sftp.Request, initialSize int64, // rename between user root dir and a virtual folder included in user quota return true } - if !c.hasSpace(true, request.Target) { + quotaResult := c.hasSpace(true, request.Target) + if !quotaResult.HasSpace { if initialSize != -1 { // we are overquota but we are overwriting a file so we check the quota size - if c.hasSpace(false, request.Target) { + quotaResult = c.hasSpace(false, request.Target) + if quotaResult.HasSpace { // we have enough quota size return true } @@ -655,41 +659,53 @@ func (c Connection) hasSpaceForRename(request *sftp.Request, initialSize int64, return true } -func (c Connection) hasSpace(checkFiles bool, requestPath string) bool { - if dataprovider.GetQuotaTracking() == 0 { - return true +func (c Connection) hasSpace(checkFiles bool, requestPath string) vfs.QuotaCheckResult { + result := vfs.QuotaCheckResult{ + HasSpace: true, + AllowedSize: 0, + AllowedFiles: 0, + UsedSize: 0, + UsedFiles: 0, + QuotaSize: 0, + QuotaFiles: 0, + } + + if dataprovider.GetQuotaTracking() == 0 { + return result } - var quotaSize, usedSize int64 - var quotaFiles, numFiles int var err error var vfolder vfs.VirtualFolder vfolder, err = c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil && !vfolder.IsIncludedInUserQuota() { if vfolder.HasNoQuotaRestrictions(checkFiles) { - return true + return result } - quotaSize = vfolder.QuotaSize - quotaFiles = vfolder.QuotaFiles - numFiles, usedSize, err = dataprovider.GetUsedVirtualFolderQuota(dataProvider, vfolder.MappedPath) + result.QuotaSize = vfolder.QuotaSize + result.QuotaFiles = vfolder.QuotaFiles + result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedVirtualFolderQuota(dataProvider, vfolder.MappedPath) } else { if c.User.HasNoQuotaRestrictions(checkFiles) { - return true + return result } - quotaSize = c.User.QuotaSize - quotaFiles = c.User.QuotaFiles - numFiles, usedSize, err = dataprovider.GetUsedQuota(dataProvider, c.User.Username) + result.QuotaSize = c.User.QuotaSize + result.QuotaFiles = c.User.QuotaFiles + result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedQuota(dataProvider, c.User.Username) } if err != nil { c.Log(logger.LevelWarn, logSender, "error getting used quota for %#v request path %#v: %v", c.User.Username, requestPath, err) - return false + result.HasSpace = false + return result } - if (checkFiles && quotaFiles > 0 && numFiles >= quotaFiles) || - (quotaSize > 0 && usedSize >= quotaSize) { + result.AllowedFiles = result.QuotaFiles - result.UsedFiles + result.AllowedSize = result.QuotaSize - result.UsedSize + if (checkFiles && result.QuotaFiles > 0 && result.UsedFiles >= result.QuotaFiles) || + (result.QuotaSize > 0 && result.UsedSize >= result.QuotaSize) { c.Log(logger.LevelDebug, logSender, "quota exceed for user %#v, request path %#v, num files: %v/%v, size: %v/%v check files: %v", - c.User.Username, requestPath, numFiles, quotaFiles, usedSize, quotaSize, checkFiles) - return false + c.User.Username, requestPath, result.UsedFiles, result.QuotaFiles, result.UsedSize, result.QuotaSize, checkFiles) + result.HasSpace = false + return result } - return true + return result } func (c Connection) close() error { diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 72ae4e2f..392bee33 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -587,7 +587,8 @@ func TestSFTPGetUsedQuota(t *testing.T) { connection := Connection{ User: u, } - assert.False(t, connection.hasSpace(false, "/")) + quotaResult := connection.hasSpace(false, "/") + assert.False(t, quotaResult.HasSpace) } func TestSupportedSSHCommands(t *testing.T) { diff --git a/sftpd/scp.go b/sftpd/scp.go index f1d7bc2e..0952aead 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -188,7 +188,8 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err } func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { - if !c.connection.hasSpace(true, requestPath) { + quotaResult := c.connection.hasSpace(true, requestPath) + if !quotaResult.HasSpace { err := fmt.Errorf("denying file write due to quota limits") c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err) c.sendErrorMessage(err) diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 56dfe976..a6c4a05e 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -5455,6 +5455,169 @@ func TestSSHCopy(t *testing.T) { assert.NoError(t, err) } +func TestSSHCopyQuotaLimits(t *testing.T) { + usePubKey := false + testFileSize := int64(131072) + testFileSize1 := int64(65536) + testFileSize2 := int64(32768) + u := getTestUser(usePubKey) + u.QuotaFiles = 3 + u.QuotaSize = testFileSize + testFileSize1 + 1 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 3, + QuotaSize: testFileSize + testFileSize1 + 1, + }) + u.Filters.FileExtensions = []dataprovider.ExtensionsFilter{ + { + Path: "/", + DeniedExtensions: []string{".denied"}, + }, + } + err := os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + user, _, err := httpd.AddUser(u, http.StatusOK) + assert.NoError(t, err) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + testDir := "testDir" + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + testFileName2 := "test_file2.dat" + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + err = createTestFile(testFilePath2, testFileSize2) + assert.NoError(t, err) + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, testDir)) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2), testFileSize2, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2+".dupl"), testFileSize2, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2), testFileSize2, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2+".dupl"), testFileSize2, client) + assert.NoError(t, err) + // user quota: 2 files, size: 32768*2, folder2 quota: 2 files, size: 32768*2 + // try to duplicate testDir, this will result in 4 file (over quota) and 32768*4 bytes (not over quota) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", testDir, testDir+"_copy"), user, usePubKey) + assert.Error(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir), + path.Join(vdirPath2, testDir+"_copy")), user, usePubKey) + assert.Error(t, err) + + _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey) + assert.NoError(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey) + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + folder, _, err := httpd.GetFolders(0, 0, mappedPath1, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, folder, 1) { + f := folder[0] + assert.Equal(t, 0, f.UsedQuotaFiles) + assert.Equal(t, int64(0), f.UsedQuotaSize) + } + folder, _, err = httpd.GetFolders(0, 0, mappedPath2, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, folder, 1) { + f := folder[0] + assert.Equal(t, 0, f.UsedQuotaFiles) + assert.Equal(t, int64(0), f.UsedQuotaSize) + } + err = client.Mkdir(path.Join(vdirPath1, testDir)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, testDir)) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + + // vdir1 is included in user quota, file limit will be exceeded + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir), "/"), user, usePubKey) + assert.Error(t, err) + + // vdir2 size limit will be exceeded + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir, testFileName), + vdirPath2+"/"), user, usePubKey) + assert.Error(t, err) + // now decrease the limits + user.QuotaFiles = 1 + user.QuotaSize = testFileSize * 10 + user.VirtualFolders[1].QuotaSize = testFileSize + user.VirtualFolders[1].QuotaFiles = 10 + user, _, err = httpd.UpdateUser(user, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.QuotaFiles) + assert.Equal(t, testFileSize*10, user.QuotaSize) + if assert.Len(t, user.VirtualFolders, 2) { + f := user.VirtualFolders[1] + assert.Equal(t, testFileSize, f.QuotaSize) + assert.Equal(t, 10, f.QuotaFiles) + } + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir), + path.Join(vdirPath2, testDir+".copy")), user, usePubKey) + assert.Error(t, err) + + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir), + testDir+".copy"), user, usePubKey) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + } + + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpd.RemoveFolder(vfs.BaseVirtualFolder{MappedPath: mappedPath1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpd.RemoveFolder(vfs.BaseVirtualFolder{MappedPath: mappedPath2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + func TestSSHRemove(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index bac5f791..40214a55 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -163,6 +163,9 @@ func (c *sshCommand) handeSFTPGoCopy() error { err := errors.New("unsupported copy source: only files and directories are supported") return c.sendErrorResponse(err) } + if err := c.checkCopyQuota(filesNum, filesSize, sshDestPath); err != nil { + return c.sendErrorResponse(err) + } c.connection.Log(logger.LevelDebug, logSenderSSH, "start copy %#v -> %#v", fsSourcePath, fsDestPath) err = fscopy.Copy(fsSourcePath, fsDestPath) if err != nil { @@ -301,7 +304,8 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { return c.sendErrorResponse(errUnsupportedConfig) } sshDestPath := c.getDestPath() - if !c.connection.hasSpace(true, command.quotaCheckPath) { + quotaResult := c.connection.hasSpace(true, command.quotaCheckPath) + if !quotaResult.HasSpace { return c.sendErrorResponse(errQuotaExceeded) } perms := []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermListItems, @@ -342,9 +346,10 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { var once sync.Once commandResponse := make(chan bool) + remainingQuotaSize := quotaResult.GetRemainingSize() + go func() { defer stdin.Close() - remainingQuotaSize := c.connection.User.GetRemaingQuotaSize(sshDestPath) transfer := Transfer{ file: nil, path: command.fsPath, @@ -629,6 +634,30 @@ func (c *sshCommand) checkCopyDestination(fsDestPath string) error { return nil } +func (c *sshCommand) checkCopyQuota(numFiles int, filesSize int64, requestPath string) error { + quotaResult := c.connection.hasSpace(true, requestPath) + if !quotaResult.HasSpace { + return errQuotaExceeded + } + if quotaResult.QuotaFiles > 0 { + remainingFiles := quotaResult.GetRemainingFiles() + if remainingFiles < numFiles { + c.connection.Log(logger.LevelDebug, logSenderSSH, "copy not allowed, file limit will be exceeded, "+ + "remaining files: %v to copy: %v", remainingFiles, numFiles) + return errQuotaExceeded + } + } + if quotaResult.QuotaSize > 0 { + remainingSize := quotaResult.GetRemainingSize() + if remainingSize < filesSize { + c.connection.Log(logger.LevelDebug, logSenderSSH, "copy not allowed, size limit will be exceeded, "+ + "remaining size: %v to copy: %v", remainingSize, filesSize) + return errQuotaExceeded + } + } + return nil +} + func (c *sshCommand) getSizeForPath(name string) (int, int64, error) { if dataprovider.GetQuotaTracking() > 0 { fi, err := c.connection.fs.Lstat(name) diff --git a/vfs/vfs.go b/vfs/vfs.go index ed465c2d..859c24c7 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -45,6 +45,33 @@ type Fs interface { Join(elem ...string) string } +// QuotaCheckResult defines the result for a quota check +type QuotaCheckResult struct { + HasSpace bool + AllowedSize int64 + AllowedFiles int + UsedSize int64 + UsedFiles int + QuotaSize int64 + QuotaFiles int +} + +// GetRemainingSize returns the remaining allowed size +func (q *QuotaCheckResult) GetRemainingSize() int64 { + if q.QuotaSize > 0 { + return q.QuotaSize - q.UsedSize + } + return 0 +} + +// GetRemainigFiles returns the remaining allowed files +func (q *QuotaCheckResult) GetRemainingFiles() int { + if q.QuotaFiles > 0 { + return q.QuotaFiles - q.UsedFiles + } + return 0 +} + // S3FsConfig defines the configuration for S3 based filesystem type S3FsConfig struct { Bucket string `json:"bucket,omitempty"`