From bc5779e26f238e147536ed3b8642731d29f02c1f Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 7 Sep 2019 23:10:20 +0200 Subject: [PATCH] actions: don't execute actions on errors detect upload/download errors and don't execute actions if a transfer error happen. To detect SFTP errors this patch is needed: https://github.com/pkg/sftp/pull/307 --- docker/sftpgo/alpine/README.md | 6 +- sftpd/handler.go | 6 ++ sftpd/internal_test.go | 2 + sftpd/scp.go | 8 +++ sftpd/sftpd_test.go | 114 ++++++++++++++++++++++++++------- sftpd/transfer.go | 28 ++++++-- 6 files changed, 132 insertions(+), 32 deletions(-) diff --git a/docker/sftpgo/alpine/README.md b/docker/sftpgo/alpine/README.md index cd41fb6b..5546c502 100644 --- a/docker/sftpgo/alpine/README.md +++ b/docker/sftpgo/alpine/README.md @@ -36,8 +36,6 @@ The script `entrypoint.sh` makes sure to correct the permissions of directories Several images can be run with another parameters. ### Custom systemd script -An example of systemd script is present [here](sftpgo-docker.service), with `Environment` parameter to set `PUID` and `GUID` +An example of systemd script is present [here](sftpgo.service), with `Environment` parameter to set `PUID` and `GUID` -`WorkingDirectory` parameter must be exist with one file in this directory like `sftpgo-${PUID}.env` corresponding to the variable file for SFTPgo instance. - -Enjoy \ No newline at end of file +`WorkingDirectory` parameter must be exist with one file in this directory like `sftpgo-${PUID}.env` corresponding to the variable file for SFTPGo instance. \ No newline at end of file diff --git a/sftpd/handler.go b/sftpd/handler.go index bb4cab22..e5ea92d3 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -85,6 +85,8 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { lastActivity: time.Now(), isNewFile: false, protocol: c.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) return &transfer, nil @@ -380,6 +382,8 @@ func (c Connection) handleSFTPUploadToNewFile(requestPath, filePath string) (io. lastActivity: time.Now(), isNewFile: true, protocol: c.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) return &transfer, nil @@ -434,6 +438,8 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re lastActivity: time.Now(), isNewFile: false, protocol: c.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) return &transfer, nil diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 61ebb3dc..875f5e09 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -620,6 +620,8 @@ func TestSCPUploadFiledata(t *testing.T) { lastActivity: time.Now(), isNewFile: true, protocol: connection.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) err := scpCommand.getUploadFileData(2, &transfer) diff --git a/sftpd/scp.go b/sftpd/scp.go index 73814738..62c4e3a3 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -151,6 +151,7 @@ func (c *scpCommand) handleCreateDir(dirPath string) error { func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) error { err := c.sendConfirmationMessage() if err != nil { + transfer.TransferError(err) transfer.Close() return err } @@ -162,6 +163,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err n, err := c.channel.Read(buf) if err != nil { c.sendErrorMessage(err.Error()) + transfer.TransferError(err) transfer.Close() return err } @@ -177,6 +179,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err } err = c.readConfirmationMessage() if err != nil { + transfer.TransferError(err) transfer.Close() return err } @@ -226,6 +229,8 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i lastActivity: time.Now(), isNewFile: isNewFile, protocol: c.connection.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) @@ -468,6 +473,8 @@ func (c *scpCommand) handleDownload(filePath string) error { lastActivity: time.Now(), isNewFile: false, protocol: c.connection.protocol, + transferError: nil, + isFinished: false, } addTransfer(&transfer) @@ -477,6 +484,7 @@ func (c *scpCommand) handleDownload(filePath string) error { if err == nil { err = transfer.Close() } else { + transfer.TransferError(err) transfer.Close() } return err diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index e0cac1b4..7d0dfc2b 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -1998,6 +1998,68 @@ func TestSCPRemoteToRemote(t *testing.T) { } } +func TestSCPErrors(t *testing.T) { + if len(scpPath) == 0 { + t.Skip("scp command not found, unable to execute this test") + } + u := getTestUser(true) + u.UploadBandwidth = 4096 + u.DownloadBandwidth = 4096 + user, _, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + testFileSize := int64(524288) + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + err = scpUpload(testFilePath, remoteUpPath, false, false) + if err != nil { + t.Errorf("error uploading file via scp: %v", err) + } + cmd := getScpDownloadCommand(localPath, remoteDownPath, false, false) + go func() { + if cmd.Run() == nil { + t.Errorf("SCP download must fail") + } + }() + waitForActiveTransfer() + // wait some additional arbitrary time to wait for transfer activity to happen + // it is need to reach all the code in CheckIdleConnections + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + + cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false) + go func() { + if cmd.Run() == nil { + t.Errorf("SCP upload must fail") + } + }() + waitForActiveTransfer() + // wait some additional arbitrary time to wait for transfer activity to happen + // it is need to reach all the code in CheckIdleConnections + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + err = os.Remove(testFilePath) + if err != nil { + t.Errorf("error removing test file") + } + err = os.RemoveAll(user.GetHomeDir()) + if err != nil { + t.Errorf("error removing uploaded files") + } + _, err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + // End SCP tests func waitTCPListening(address string) { @@ -2178,6 +2240,35 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe } func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool) error { + cmd := getScpUploadCommand(localPath, remotePath, preserveTime, remoteToRemote) + return cmd.Run() +} + +func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error { + cmd := getScpDownloadCommand(localPath, remotePath, preserveTime, recursive) + return cmd.Run() +} + +func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive bool) *exec.Cmd { + var args []string + if preserveTime { + args = append(args, "-p") + } + if recursive { + args = append(args, "-r") + } + args = append(args, "-P") + args = append(args, "2022") + args = append(args, "-o") + args = append(args, "StrictHostKeyChecking=no") + args = append(args, "-i") + args = append(args, privateKeyPath) + args = append(args, remotePath) + args = append(args, localPath) + return exec.Command(scpPath, args...) +} + +func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRemote bool) *exec.Cmd { var args []string if remoteToRemote { args = append(args, "-3") @@ -2199,28 +2290,7 @@ func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool) args = append(args, privateKeyPath) args = append(args, localPath) args = append(args, remotePath) - cmd := exec.Command(scpPath, args...) - return cmd.Run() -} - -func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error { - var args []string - if preserveTime { - args = append(args, "-p") - } - if recursive { - args = append(args, "-r") - } - args = append(args, "-P") - args = append(args, "2022") - args = append(args, "-o") - args = append(args, "StrictHostKeyChecking=no") - args = append(args, "-i") - args = append(args, privateKeyPath) - args = append(args, remotePath) - args = append(args, localPath) - cmd := exec.Command(scpPath, args...) - return cmd.Run() + return exec.Command(scpPath, args...) } func waitForActiveTransfer() { diff --git a/sftpd/transfer.go b/sftpd/transfer.go index de162bc1..414907b8 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -32,6 +32,16 @@ type Transfer struct { lastActivity time.Time isNewFile bool protocol string + transferError error + isFinished bool +} + +// TransferError is called if there is an unexpected error. +// For example network or client issues +func (t *Transfer) TransferError(err error) { + t.transferError = err + logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: %v bytes sent: %v,"+ + "bytes received: %v", t.path, t.transferError, t.bytesSent, t.bytesReceived) } // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent. @@ -58,18 +68,23 @@ func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) { // It closes the underlying file, log the transfer info, update the user quota, for uploads, and execute any defined actions. func (t *Transfer) Close() error { err := t.file.Close() + if t.isFinished { + return err + } if t.transferType == transferUpload && t.file.Name() != t.path { err = os.Rename(t.file.Name(), t.path) logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v", t.file.Name(), t.path, err) } elapsed := time.Since(t.start).Nanoseconds() / 1000000 - if t.transferType == transferDownload { - logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol) - executeAction(operationDownload, t.user.Username, t.path, "") - } else { - logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol) - executeAction(operationUpload, t.user.Username, t.path, "") + if t.transferError == nil { + if t.transferType == transferDownload { + logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol) + executeAction(operationDownload, t.user.Username, t.path, "") + } else { + logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol) + executeAction(operationUpload, t.user.Username, t.path, "") + } } removeTransfer(t) if t.transferType == transferUpload { @@ -79,6 +94,7 @@ func (t *Transfer) Close() error { } dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false) } + t.isFinished = true return err }