From d481294519d802d35142a6d1b5af9b27f1ea1060 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 23 Jan 2020 10:19:56 +0100 Subject: [PATCH] S3: fix quota update after an upload error S3 uploads are atomic, if the upload fails we have no partial file so we have to update the user quota only if the upload succeed --- sftpd/handler.go | 8 +++++++- sftpd/internal_test.go | 37 +++++++++++++++++++++++++++++++++++++ sftpd/scp.go | 17 ++++++++++++----- sftpd/sftpd_test.go | 8 ++++---- sftpd/transfer.go | 17 ++++++++++++++--- 5 files changed, 74 insertions(+), 13 deletions(-) diff --git a/sftpd/handler.go b/sftpd/handler.go index ac96f71d..adb29e9d 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -479,11 +479,16 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re return nil, vfs.GetSFTPError(c.fs, err) } + initialSize := int64(0) if pflags.Append && osFlags&os.O_TRUNC == 0 { c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize } else { - dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) + if vfs.IsLocalOsFs(c.fs) { + dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) + } else { + initialSize = fileSize + } } vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID()) @@ -506,6 +511,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re transferError: nil, isFinished: false, minWriteOffset: minWriteOffset, + initialSize: initialSize, lock: new(sync.Mutex), } addTransfer(&transfer) diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 66e44040..86a44913 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -425,6 +425,28 @@ func TestUploadFiles(t *testing.T) { if err == nil { t.Errorf("upload new file in missing path must fail") } + c.fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) + f, _ := ioutil.TempFile("", "temp") + f.Close() + _, err = c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(activeTransfers) != 1 { + t.Errorf("unexpected number of transfer, expected 1, current: %v", len(activeTransfers)) + } + transfer := activeTransfers[0] + if transfer.initialSize != 123 { + t.Errorf("unexpected initial size: %v", transfer.initialSize) + } + err = transfer.Close() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(activeTransfers) != 0 { + t.Errorf("unexpected number of transfer, expected 0, current: %v", len(activeTransfers)) + } + os.Remove(f.Name()) uploadMode = oldUploadMode } @@ -899,6 +921,17 @@ func TestSystemCommandErrors(t *testing.T) { } } +func TestTransferUpdateQuota(t *testing.T) { + transfer := Transfer{ + transferType: transferUpload, + bytesReceived: 123, + lock: new(sync.Mutex)} + transfer.TransferError(errors.New("fake error")) + if transfer.updateQuota(1) { + t.Errorf("update quota must fail, there is a error and this is a remote upload") + } +} + func TestGetConnectionInfo(t *testing.T) { c := ConnectionStatus{ Username: "test_user", @@ -1222,6 +1255,10 @@ func TestSCPErrorsMockFs(t *testing.T) { if err != errFake { t.Errorf("unexpected error: %v", err) } + err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4) + if err != nil { + t.Errorf("unexpected error: %v", err) + } os.Remove(testfile) } diff --git a/sftpd/scp.go b/sftpd/scp.go index 6871861c..94552d49 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -181,7 +181,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err return c.sendConfirmationMessage() } -func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error { +func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64) error { if !c.connection.hasSpace(true) { err := fmt.Errorf("denying file write due to space limit") c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err) @@ -189,6 +189,14 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i return err } + initialSize := int64(0) + if !isNewFile { + if vfs.IsLocalOsFs(c.connection.fs) { + dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false) + } else { + initialSize = fileSize + } + } file, w, cancelFn, err := c.connection.fs.Create(filePath, 0) if err != nil { c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err) @@ -216,6 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i transferError: nil, isFinished: false, minWriteOffset: 0, + initialSize: initialSize, lock: new(sync.Mutex), } addTransfer(&transfer) @@ -246,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error c.sendErrorMessage(err.Error()) return err } - return c.handleUploadFile(p, filePath, sizeToRead, true) + return c.handleUploadFile(p, filePath, sizeToRead, true, 0) } if statErr != nil { @@ -279,9 +288,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error } } - dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false) - - return c.handleUploadFile(p, filePath, sizeToRead, false) + return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size()) } func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 2dbd712b..d0a25d53 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -3123,7 +3123,7 @@ func TestResolvePaths(t *testing.T) { } path = "../test/sub" resolved, err = fs.ResolvePath(filepath.ToSlash(path)) - if fs.Name() == "osfs" { + if vfs.IsLocalOsFs(fs) { if err == nil { t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) } @@ -3134,7 +3134,7 @@ func TestResolvePaths(t *testing.T) { } path = "../../../test/../sub" resolved, err = fs.ResolvePath(filepath.ToSlash(path)) - if fs.Name() == "osfs" { + if vfs.IsLocalOsFs(fs) { if err == nil { t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) } @@ -4624,7 +4624,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ content := []byte("#!/bin/sh\n\n") q, _ := json.Marshal(questions) echos := []bool{} - for index, _ := range questions { + for index := range questions { echos = append(echos, index%2 == 0) } e, _ := json.Marshal(echos) @@ -4633,7 +4633,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ } else { content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...) } - for index, _ := range questions { + for index := range questions { content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...) } if sleepTime > 0 { diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 7ec2fa55..5e691323 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -44,6 +44,7 @@ type Transfer struct { isFinished bool minWriteOffset int64 expectedSize int64 + initialSize int64 lock *sync.Mutex } @@ -163,9 +164,7 @@ func (t *Transfer) Close() error { } metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError) removeTransfer(t) - if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) { - dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false) - } + t.updateQuota(numFiles) return err } @@ -181,6 +180,18 @@ func (t *Transfer) closeIO() error { return err } +func (t *Transfer) updateQuota(numFiles int) bool { + // S3 uploads are atomic, if there is an error nothing is uploaded + if t.file == nil && t.transferError != nil { + return false + } + if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) { + dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) + return true + } + return false +} + func (t *Transfer) checkDownloadSize() { if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize { t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize)