diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 138ddb8d..09479840 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -96,39 +96,17 @@ func sqlCommonGetUserByID(ID int64) (User, error) { } func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error { - var usedFiles int - var usedSize int64 - var err error - if reset { - usedFiles = 0 - usedSize = 0 - } else { - usedFiles, usedSize, err = p.getUsedQuota(username) - if err != nil { - return err - } - } - usedFiles += filesAdd - usedSize += sizeAdd - if usedFiles < 0 { - logger.Warn(logSender, "used files is negative, probably some files were added and not tracked, please rescan quota!") - usedFiles = 0 - } - if usedSize < 0 { - logger.Warn(logSender, "used files is negative, probably some files were added and not tracked, please rescan quota!") - usedSize = 0 - } - - q := getUpdateQuotaQuery() + q := getUpdateQuotaQuery(reset) stmt, err := dbHandle.Prepare(q) if err != nil { logger.Debug(logSender, "error preparing database query %v: %v", q, err) return err } defer stmt.Close() - _, err = stmt.Exec(usedSize, usedFiles, utils.GetTimeAsMsSinceEpoch(time.Now()), username) + _, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - logger.Debug(logSender, "quota updated for user %v, new files: %v new size: %v", username, usedFiles, usedSize) + logger.Debug(logSender, "quota updated for user %v, files increment: %v size increment: %v is reset? %v", + username, filesAdd, sizeAdd, reset) } else { logger.Warn(logSender, "error updating quota for username %v: %v", username, err) } diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 07d89cdb..393bc552 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -36,8 +36,12 @@ func getUsersQuery(order string, username string) string { order, sqlPlaceholders[0], sqlPlaceholders[1]) } -func getUpdateQuotaQuery() string { - return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_scan = %v +func getUpdateQuotaQuery(reset bool) string { + if reset { + return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_scan = %v + WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_scan = %v WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } diff --git a/sftpd/handler.go b/sftpd/handler.go index 3b17280e..1b857440 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -168,8 +168,9 @@ func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { } if trunc { + // the file is truncated so we need to decrease quota size but not quota files logger.Debug(logSender, "file truncation requested update quota for user %v", c.User.Username) - dataprovider.UpdateUserQuota(dataProvider, c.User.Username, -1, -stat.Size(), false) + dataprovider.UpdateUserQuota(dataProvider, c.User.Username, 0, -stat.Size(), false) } utils.SetPathPermissions(p, c.User.GetUID(), c.User.GetGID()) @@ -185,7 +186,7 @@ func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { user: c.User, connectionID: c.ID, transferType: transferUpload, - isNewFile: trunc, + isNewFile: !trunc, } addTransfer(&transfer) return &transfer, nil diff --git a/sftpd/sftpd.go b/sftpd/sftpd.go index 020d28f6..c4b144cc 100644 --- a/sftpd/sftpd.go +++ b/sftpd/sftpd.go @@ -298,8 +298,8 @@ func executeAction(operation string, username string, path string, target string if _, err = os.Stat(actions.Command); err == nil { command := exec.Command(actions.Command, operation, username, path, target) err = command.Start() - logger.Debug(logSender, "executed command \"%v\" with arguments: %v, %v, %v, error: %v", - actions.Command, operation, path, target, err) + logger.Debug(logSender, "executed command \"%v\" with arguments: %v, %v, %v, %v, error: %v", + actions.Command, operation, username, path, target, err) } else { logger.Warn(logSender, "Invalid action command \"%v\" : %v", actions.Command, err) } diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 1c785ef9..7870513f 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -574,6 +574,66 @@ func TestMaxSessions(t *testing.T) { } } +func TestQuotaFileReplace(t *testing.T) { + usePubKey := false + user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + user, err = api.GetUserByID(user.ID, http.StatusOK) + if err != nil { + t.Errorf("error getting user: %v", err) + } + if expectedQuotaFiles != user.UsedQuotaFiles { + t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles) + } + if expectedQuotaSize != user.UsedQuotaSize { + t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize) + } + // now replace the same file, the quota must not change + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + user, err = api.GetUserByID(user.ID, http.StatusOK) + if err != nil { + t.Errorf("error getting user: %v", err) + } + if expectedQuotaFiles != user.UsedQuotaFiles { + t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles) + } + if expectedQuotaSize != user.UsedQuotaSize { + t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize) + } + err = client.Remove(testFileName) + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + func TestQuotaScan(t *testing.T) { usePubKey := false user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) @@ -670,14 +730,18 @@ func TestQuotaSize(t *testing.T) { if err != nil { t.Errorf("unable to create test file: %v", err) } - err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + err = sftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client) if err != nil { t.Errorf("file upload error: %v", err) } - err = sftpUploadFile(testFilePath, testFileName+".1", testFileSize, client) + err = sftpUploadFile(testFilePath, testFileName+".quota.1", testFileSize, client) if err == nil { t.Errorf("user is over quota file upload must fail") } + err = client.Remove(testFileName + ".quota") + if err != nil { + t.Errorf("error removing uploaded file: %v", err) + } } err = api.RemoveUser(user, http.StatusOK) if err != nil { diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 235567c5..e5853207 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -57,9 +57,11 @@ func (t *Transfer) Close() error { executeAction(operationUpload, t.user.Username, t.path, "") } removeTransfer(t) - if t.transferType == transferUpload && t.bytesReceived > 0 && t.isNewFile { + if t.transferType == transferUpload { numFiles := 0 - numFiles++ + if t.isNewFile { + numFiles = 1 + } dataprovider.UpdateUserQuota(dataProvider, t.user.Username, numFiles, t.bytesReceived, false) } return err