scp: fix quota update after file overwrite

added a test case too
This commit is contained in:
Nicola Murino 2019-09-02 23:12:41 +02:00
parent 4a1b67454e
commit dc5eeb54fd
4 changed files with 17 additions and 8 deletions

View file

@ -57,7 +57,7 @@ Version info, such as git commit and build date, can be embedded setting the fol
For example you can build using the following command: For example you can build using the following command:
```bash ```bash
go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo
``` ```
and you will get a version that includes git commit and build date like this one: and you will get a version that includes git commit and build date like this one:
@ -71,7 +71,7 @@ For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/mas
Alternately you can use distro packages: Alternately you can use distro packages:
- Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/packages/sftpgo-git/ "SFTPGo") - Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/packages/sftpgo/ "SFTPGo")
## Configuration ## Configuration

View file

@ -188,8 +188,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
return c.sendConfirmationMessage() return c.sendConfirmationMessage()
} }
func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64) error { func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error {
logger.Debug(logSenderSCP, "upload to new file: %v", filePath)
if !c.connection.hasSpace(true) { if !c.connection.hasSpace(true) {
err := fmt.Errorf("denying file write due to space limit") err := fmt.Errorf("denying file write due to space limit")
logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", filePath, err) logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", filePath, err)
@ -225,7 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
connectionID: c.connection.ID, connectionID: c.connection.ID,
transferType: transferUpload, transferType: transferUpload,
lastActivity: time.Now(), lastActivity: time.Now(),
isNewFile: true, isNewFile: isNewFile,
protocol: c.connection.protocol, protocol: c.connection.protocol,
} }
addTransfer(&transfer) addTransfer(&transfer)
@ -256,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
} }
stat, statErr := os.Stat(p) stat, statErr := os.Stat(p)
if os.IsNotExist(statErr) { if os.IsNotExist(statErr) {
return c.handleUploadFile(p, filePath, sizeToRead) return c.handleUploadFile(p, filePath, sizeToRead, true)
} }
if statErr != nil { if statErr != nil {
@ -284,7 +283,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false) dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false)
return c.handleUploadFile(p, filePath, sizeToRead) return c.handleUploadFile(p, filePath, sizeToRead, false)
} }
func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error { func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {

View file

@ -1503,10 +1503,12 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
} }
usePubKey := true usePubKey := true
u := getTestUser(usePubKey) u := getTestUser(usePubKey)
u.QuotaFiles = 1000
user, _, err := api.AddUser(u, http.StatusOK) user, _, err := api.AddUser(u, http.StatusOK)
if err != nil { if err != nil {
t.Errorf("unable to add user: %v", err) t.Errorf("unable to add user: %v", err)
} }
os.RemoveAll(user.GetHomeDir())
testFileName := "test_file.dat" testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName) testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(32760) testFileSize := int64(32760)
@ -1524,6 +1526,14 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("error uploading existing file via scp: %v", err) t.Errorf("error uploading existing file via scp: %v", err)
} }
user, _, err = api.GetUserByID(user.ID, http.StatusOK)
if err != nil {
t.Errorf("error getting user: %v", err)
}
if user.UsedQuotaSize != testFileSize || user.UsedQuotaFiles != 1 {
t.Errorf("update quota error on file overwrite, actual size: %v, expected: %v actual files: %v, expected: 1",
user.UsedQuotaSize, testFileSize, user.UsedQuotaFiles)
}
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat") localPath := filepath.Join(homeBasePath, "scp_download.dat")
err = scpDownload(localPath, remoteDownPath, false, false) err = scpDownload(localPath, remoteDownPath, false, false)

View file

@ -1,6 +1,6 @@
package utils package utils
const version = "0.9.1" const version = "0.9.1-dev"
var ( var (
commit = "" commit = ""