diff --git a/sftpd/handler.go b/sftpd/handler.go index 52eaefa7..2af238ac 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -37,7 +37,6 @@ type Connection struct { // last activity for this connection lastActivity time.Time protocol string - lock *sync.Mutex netConn net.Conn channel ssh.Channel command string @@ -61,10 +60,8 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { return nil, getSFTPErrorFromOSError(err) } - c.lock.Lock() - defer c.lock.Unlock() - - if _, err := os.Stat(p); err != nil { + fi, err := os.Stat(p) + if err != nil { return nil, getSFTPErrorFromOSError(err) } @@ -91,6 +88,8 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { transferError: nil, isFinished: false, minWriteOffset: 0, + expectedSize: fi.Size(), + lock: new(sync.Mutex), } addTransfer(&transfer) return &transfer, nil @@ -109,9 +108,6 @@ func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { filePath = getUploadTempFilePath(p) } - c.lock.Lock() - defer c.lock.Unlock() - stat, statErr := os.Stat(p) if os.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) { @@ -448,6 +444,7 @@ func (c Connection) handleSFTPUploadToNewFile(requestPath, filePath string) (io. transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) return &transfer, nil @@ -503,6 +500,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re transferError: nil, isFinished: false, minWriteOffset: minWriteOffset, + lock: new(sync.Mutex), } addTransfer(&transfer) return &transfer, nil diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index a93e84a6..2e052df6 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -9,6 +9,7 @@ import ( "os" "runtime" "strings" + "sync" "testing" "time" @@ -136,6 +137,7 @@ func TestUploadResumeInvalidOffset(t *testing.T) { transferError: nil, isFinished: false, minWriteOffset: 10, + lock: new(sync.Mutex), } _, err := transfer.WriteAt([]byte("test"), 0) if err == nil { @@ -144,6 +146,76 @@ func TestUploadResumeInvalidOffset(t *testing.T) { os.Remove(testfile) } +func TestIncompleteDownload(t *testing.T) { + testfile := "testfile" + file, _ := os.Create(testfile) + transfer := Transfer{ + file: file, + path: file.Name(), + start: time.Now(), + bytesSent: 0, + bytesReceived: 0, + user: dataprovider.User{ + Username: "testuser", + }, + connectionID: "", + transferType: transferDownload, + lastActivity: time.Now(), + isNewFile: false, + protocol: protocolSFTP, + transferError: nil, + isFinished: false, + minWriteOffset: 0, + expectedSize: 10, + lock: new(sync.Mutex), + } + err := transfer.Close() + if err == nil { + t.Error("upoload must fail the expected size does not match") + } + os.Remove(testfile) +} + +func TestReadWriteErrors(t *testing.T) { + testfile := "testfile" + file, _ := os.Create(testfile) + transfer := Transfer{ + file: file, + path: file.Name(), + start: time.Now(), + bytesSent: 0, + bytesReceived: 0, + user: dataprovider.User{ + Username: "testuser", + }, + connectionID: "", + transferType: transferDownload, + lastActivity: time.Now(), + isNewFile: false, + protocol: protocolSFTP, + transferError: nil, + isFinished: false, + minWriteOffset: 0, + expectedSize: 10, + lock: new(sync.Mutex), + } + file.Close() + _, err := transfer.WriteAt([]byte("test"), 0) + if err == nil { + t.Error("writing to closed file must fail") + } + buf := make([]byte, 32768) + _, err = transfer.ReadAt(buf, 0) + if err == nil { + t.Error("reading from a closed file must fail") + } + err = transfer.Close() + if err == nil { + t.Error("upoload must fail the expected size does not match") + } + os.Remove(testfile) +} + func TestUploadFiles(t *testing.T) { oldUploadMode := uploadMode uploadMode = uploadModeAtomic @@ -550,7 +622,9 @@ func TestSystemCommandErrors(t *testing.T) { WriteError: nil, } sshCmd.connection.channel = &mockSSHChannel - transfer := Transfer{transferType: transferDownload} + transfer := Transfer{ + transferType: transferDownload, + lock: new(sync.Mutex)} destBuff := make([]byte, 65535) dst := bytes.NewBuffer(destBuff) _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel, 0) @@ -1071,6 +1145,7 @@ func TestSCPUploadFiledata(t *testing.T) { transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) err := scpCommand.getUploadFileData(2, &transfer) @@ -1151,12 +1226,13 @@ func TestUploadError(t *testing.T) { transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) transfer.TransferError(fmt.Errorf("fake error")) transfer.Close() if transfer.bytesReceived > 0 { - t.Errorf("byte sent should be 0 for a failed transfer: %v", transfer.bytesSent) + t.Errorf("bytes received should be 0 for a failed transfer: %v", transfer.bytesReceived) } _, err := os.Stat(testfile) if !os.IsNotExist(err) { diff --git a/sftpd/scp.go b/sftpd/scp.go index 4b9e3a50..dc8161bc 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/drakkan/sftpgo/dataprovider" @@ -212,6 +213,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) @@ -387,8 +389,9 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra } buf := make([]byte, 32768) + var n int for { - n, err := transfer.ReadAt(buf, readed) + n, err = transfer.ReadAt(buf, readed) if err == nil || err == io.EOF { if n > 0 { _, err = c.connection.channel.Write(buf[:n]) @@ -471,6 +474,8 @@ func (c *scpCommand) handleDownload(filePath string) error { transferError: nil, isFinished: false, minWriteOffset: 0, + expectedSize: stat.Size(), + lock: new(sync.Mutex), } addTransfer(&transfer) diff --git a/sftpd/server.go b/sftpd/server.go index 272b1c63..37185872 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -14,7 +14,6 @@ import ( "os" "path/filepath" "strconv" - "sync" "time" "github.com/drakkan/sftpgo/dataprovider" @@ -274,7 +273,6 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server RemoteAddr: remoteAddr, StartTime: time.Now(), lastActivity: time.Now(), - lock: new(sync.Mutex), netConn: conn, channel: nil, } diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 843dcfdd..5edf0770 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -200,6 +200,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) defer removeTransfer(&transfer) @@ -227,6 +228,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) defer removeTransfer(&transfer) @@ -255,6 +257,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { transferError: nil, isFinished: false, minWriteOffset: 0, + lock: new(sync.Mutex), } addTransfer(&transfer) defer removeTransfer(&transfer) diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 6e0ec256..d4b72584 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -1,9 +1,11 @@ package sftpd import ( + "errors" "fmt" "io" "os" + "sync" "time" "github.com/drakkan/sftpgo/dataprovider" @@ -33,11 +35,18 @@ type Transfer struct { transferError error isFinished bool minWriteOffset int64 + expectedSize int64 + lock *sync.Mutex } // TransferError is called if there is an unexpected error. // For example network or client issues func (t *Transfer) TransferError(err error) { + t.lock.Lock() + defer t.lock.Unlock() + if t.transferError != nil { + return + } t.transferError = err elapsed := time.Since(t.start).Nanoseconds() / 1000000 logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+ @@ -49,7 +58,13 @@ func (t *Transfer) TransferError(err error) { func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) { t.lastActivity = time.Now() readed, e := t.file.ReadAt(p, off) + t.lock.Lock() t.bytesSent += int64(readed) + t.lock.Unlock() + if e != nil && e != io.EOF { + t.TransferError(e) + return readed, e + } t.handleThrottle() return readed, e } @@ -59,11 +74,18 @@ func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) { func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) { t.lastActivity = time.Now() if off < t.minWriteOffset { - logger.Warn(logSender, t.connectionID, "Invalid write offset %v minimum valid value %v", off, t.minWriteOffset) - return 0, fmt.Errorf("invalid write offset %v", off) + err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset) + t.TransferError(err) + return 0, err } written, e := t.file.WriteAt(p, off) + t.lock.Lock() t.bytesReceived += int64(written) + t.lock.Unlock() + if e != nil { + t.TransferError(e) + return written, e + } t.handleThrottle() return written, e } @@ -74,15 +96,18 @@ func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) { // If there is an error no action will be executed and, in atomic mode, we try to delete // the temporary file func (t *Transfer) Close() error { - err := t.file.Close() + t.lock.Lock() + defer t.lock.Unlock() if t.isFinished { - return err + return errors.New("transfer already closed") } + err := t.file.Close() t.isFinished = true numFiles := 0 if t.isNewFile { numFiles = 1 } + t.checkDownloadSize() if t.transferType == transferUpload && t.file.Name() != t.path { if t.transferError == nil || uploadMode == uploadModeAtomicWithResume { err = os.Rename(t.file.Name(), t.path) @@ -107,6 +132,11 @@ func (t *Transfer) Close() error { logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol) go executeAction(operationUpload, t.user.Username, t.path, "", "", t.bytesReceived+t.minWriteOffset) } + } else { + logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path) + if err == nil { + err = t.transferError + } } metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError) removeTransfer(t) @@ -116,6 +146,12 @@ func (t *Transfer) Close() error { return err } +func (t *Transfer) checkDownloadSize() { + if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize { + t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize) + } +} + func (t *Transfer) handleThrottle() { var wantedBandwidth int64 var trasferredBytes int64