mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-21 23:20:24 +00:00
actions: don't execute actions on errors
detect upload/download errors and don't execute actions if a transfer error happen. To detect SFTP errors this patch is needed: https://github.com/pkg/sftp/pull/307
This commit is contained in:
parent
2a7e56ed29
commit
bc5779e26f
6 changed files with 132 additions and 32 deletions
|
@ -36,8 +36,6 @@ The script `entrypoint.sh` makes sure to correct the permissions of directories
|
|||
Several images can be run with another parameters.
|
||||
|
||||
### Custom systemd script
|
||||
An example of systemd script is present [here](sftpgo-docker.service), with `Environment` parameter to set `PUID` and `GUID`
|
||||
An example of systemd script is present [here](sftpgo.service), with `Environment` parameter to set `PUID` and `GUID`
|
||||
|
||||
`WorkingDirectory` parameter must be exist with one file in this directory like `sftpgo-${PUID}.env` corresponding to the variable file for SFTPgo instance.
|
||||
|
||||
Enjoy
|
||||
`WorkingDirectory` parameter must be exist with one file in this directory like `sftpgo-${PUID}.env` corresponding to the variable file for SFTPGo instance.
|
|
@ -85,6 +85,8 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
@ -380,6 +382,8 @@ func (c Connection) handleSFTPUploadToNewFile(requestPath, filePath string) (io.
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: true,
|
||||
protocol: c.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
@ -434,6 +438,8 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
|
|
@ -620,6 +620,8 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: true,
|
||||
protocol: connection.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
err := scpCommand.getUploadFileData(2, &transfer)
|
||||
|
|
|
@ -151,6 +151,7 @@ func (c *scpCommand) handleCreateDir(dirPath string) error {
|
|||
func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) error {
|
||||
err := c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
transfer.TransferError(err)
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
|
@ -162,6 +163,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
|
|||
n, err := c.channel.Read(buf)
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
transfer.TransferError(err)
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
|
@ -177,6 +179,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
|
|||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
transfer.TransferError(err)
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
|
@ -226,6 +229,8 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: isNewFile,
|
||||
protocol: c.connection.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
|
||||
|
@ -468,6 +473,8 @@ func (c *scpCommand) handleDownload(filePath string) error {
|
|||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.connection.protocol,
|
||||
transferError: nil,
|
||||
isFinished: false,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
|
||||
|
@ -477,6 +484,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
|
|||
if err == nil {
|
||||
err = transfer.Close()
|
||||
} else {
|
||||
transfer.TransferError(err)
|
||||
transfer.Close()
|
||||
}
|
||||
return err
|
||||
|
|
|
@ -1998,6 +1998,68 @@ func TestSCPRemoteToRemote(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSCPErrors(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
u := getTestUser(true)
|
||||
u.UploadBandwidth = 4096
|
||||
u.DownloadBandwidth = 4096
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileSize := int64(524288)
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading file via scp: %v", err)
|
||||
}
|
||||
cmd := getScpDownloadCommand(localPath, remoteDownPath, false, false)
|
||||
go func() {
|
||||
if cmd.Run() == nil {
|
||||
t.Errorf("SCP download must fail")
|
||||
}
|
||||
}()
|
||||
waitForActiveTransfer()
|
||||
// wait some additional arbitrary time to wait for transfer activity to happen
|
||||
// it is need to reach all the code in CheckIdleConnections
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cmd.Process.Kill()
|
||||
|
||||
cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false)
|
||||
go func() {
|
||||
if cmd.Run() == nil {
|
||||
t.Errorf("SCP upload must fail")
|
||||
}
|
||||
}()
|
||||
waitForActiveTransfer()
|
||||
// wait some additional arbitrary time to wait for transfer activity to happen
|
||||
// it is need to reach all the code in CheckIdleConnections
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cmd.Process.Kill()
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// End SCP tests
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
|
@ -2178,6 +2240,35 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe
|
|||
}
|
||||
|
||||
func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool) error {
|
||||
cmd := getScpUploadCommand(localPath, remotePath, preserveTime, remoteToRemote)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
|
||||
cmd := getScpDownloadCommand(localPath, remotePath, preserveTime, recursive)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive bool) *exec.Cmd {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
if recursive {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, remotePath)
|
||||
args = append(args, localPath)
|
||||
return exec.Command(scpPath, args...)
|
||||
}
|
||||
|
||||
func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRemote bool) *exec.Cmd {
|
||||
var args []string
|
||||
if remoteToRemote {
|
||||
args = append(args, "-3")
|
||||
|
@ -2199,28 +2290,7 @@ func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool)
|
|||
args = append(args, privateKeyPath)
|
||||
args = append(args, localPath)
|
||||
args = append(args, remotePath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
if recursive {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, remotePath)
|
||||
args = append(args, localPath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
return exec.Command(scpPath, args...)
|
||||
}
|
||||
|
||||
func waitForActiveTransfer() {
|
||||
|
|
|
@ -32,6 +32,16 @@ type Transfer struct {
|
|||
lastActivity time.Time
|
||||
isNewFile bool
|
||||
protocol string
|
||||
transferError error
|
||||
isFinished bool
|
||||
}
|
||||
|
||||
// TransferError is called if there is an unexpected error.
|
||||
// For example network or client issues
|
||||
func (t *Transfer) TransferError(err error) {
|
||||
t.transferError = err
|
||||
logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: %v bytes sent: %v,"+
|
||||
"bytes received: %v", t.path, t.transferError, t.bytesSent, t.bytesReceived)
|
||||
}
|
||||
|
||||
// ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
|
||||
|
@ -58,18 +68,23 @@ func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
|
|||
// It closes the underlying file, log the transfer info, update the user quota, for uploads, and execute any defined actions.
|
||||
func (t *Transfer) Close() error {
|
||||
err := t.file.Close()
|
||||
if t.isFinished {
|
||||
return err
|
||||
}
|
||||
if t.transferType == transferUpload && t.file.Name() != t.path {
|
||||
err = os.Rename(t.file.Name(), t.path)
|
||||
logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
|
||||
t.file.Name(), t.path, err)
|
||||
}
|
||||
elapsed := time.Since(t.start).Nanoseconds() / 1000000
|
||||
if t.transferType == transferDownload {
|
||||
logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationDownload, t.user.Username, t.path, "")
|
||||
} else {
|
||||
logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationUpload, t.user.Username, t.path, "")
|
||||
if t.transferError == nil {
|
||||
if t.transferType == transferDownload {
|
||||
logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationDownload, t.user.Username, t.path, "")
|
||||
} else {
|
||||
logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationUpload, t.user.Username, t.path, "")
|
||||
}
|
||||
}
|
||||
removeTransfer(t)
|
||||
if t.transferType == transferUpload {
|
||||
|
@ -79,6 +94,7 @@ func (t *Transfer) Close() error {
|
|||
}
|
||||
dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false)
|
||||
}
|
||||
t.isFinished = true
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue