sftpgo-mirror/sftpd/transfer.go
Nicola Murino 0ea2ca3141 simplify data provider usage
remove the obsolete SQL scripts too. They are not required since v0.9.6
2020-07-08 19:59:31 +02:00

295 lines
8.4 KiB
Go

package sftpd
import (
"errors"
"fmt"
"io"
"os"
"path"
"sync"
"time"
"github.com/eikenb/pipeat"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/metrics"
"github.com/drakkan/sftpgo/vfs"
)
// Supported transfer directions, stored in Transfer.transferType.
const (
	// transferUpload marks a transfer that receives data from the client.
	transferUpload = iota
	// transferDownload marks a transfer that sends data to the client.
	transferDownload = iota
)
// errTransferClosed is returned by Transfer.Close when the transfer
// was already closed by a previous call.
var errTransferClosed = errors.New("transfer already closed")
// Transfer contains the transfer details for an upload or a download.
// Its ReadAt and WriteAt methods have the io.ReaderAt and io.WriterAt
// signatures and are used to handle files downloads and uploads.
type Transfer struct {
// file is used by ReadAt/WriteAt when readerAt/writerAt are not set
file *os.File
// writerAt, if not nil, receives the uploaded data instead of file
writerAt *vfs.PipeWriter
// readerAt, if not nil, provides the data to download instead of file
readerAt *pipeat.PipeReaderAt
// cancelFn, if not nil, is invoked when the first transfer error is recorded
cancelFn func()
// path is the transfer path used for logs and notifications; on Close an
// upload written to a file with a different name is renamed to this path
path string
// start is the reference time used to compute the elapsed time and the throttling
start time.Time
// bytesSent is the number of bytes read so far (downloads)
bytesSent int64
// bytesReceived is the number of bytes written so far (uploads)
bytesReceived int64
// user provides the bandwidth limits and is the target of the quota updates
user dataprovider.User
// connectionID is included in the log entries for this transfer
connectionID string
// transferType is transferUpload or transferDownload
transferType int
// lastActivity is refreshed each time data is read or written
lastActivity time.Time
// protocol is passed to the transfer log
protocol string
// transferError is the error recorded for this transfer, nil if none
transferError error
// minWriteOffset is the lowest offset accepted by WriteAt; it is added to
// bytesReceived in the upload notification
minWriteOffset int64
// initialSize is subtracted from bytesReceived when the quota is updated
initialSize int64
// lock is held by TransferError and Close and around the byte counters
// updates done in ReadAt and WriteAt
lock *sync.Mutex
// isNewFile makes Close count one more file when updating the quota
isNewFile bool
// isFinished is set by Close, a further Close returns errTransferClosed
isFinished bool
// requestPath is the path whose parent directory is used to look up the
// virtual folder for the quota update
requestPath string
// maxWriteSize, if greater than zero, is the number of bytes after which
// writes fail with errQuotaExceeded; a negative value makes
// copyFromReaderToWriter fail immediately
maxWriteSize int64
}
// TransferError records an unexpected failure, for example a network or
// client issue, that happened while the transfer was running.
// Only the first reported error is kept: if an error is already set the
// call returns without doing anything. When the error is stored the cancel
// function, if any, is invoked and a warning with the transfer counters
// is logged.
func (t *Transfer) TransferError(err error) {
	t.lock.Lock()
	defer t.lock.Unlock()

	if t.transferError != nil {
		// an error was already reported for this transfer, keep it
		return
	}
	t.transferError = err
	if cancel := t.cancelFn; cancel != nil {
		cancel()
	}
	// milliseconds elapsed since the transfer was started
	runningMs := int64(time.Since(t.start) / time.Millisecond)
	logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
		"bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, runningMs)
}
// ReadAt reads up to len(p) bytes of the file to download, starting at byte
// offset off, and adds the number of bytes read to the bytes sent counter.
// The data comes from the pipe reader if one is set, otherwise from the file.
// Any error other than io.EOF is recorded using TransferError and returned;
// on success, or on io.EOF, the download bandwidth throttling is applied
// before returning.
func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
	t.lastActivity = time.Now()

	if t.readerAt != nil {
		n, err = t.readerAt.ReadAt(p, off)
	} else {
		n, err = t.file.ReadAt(p, off)
	}

	t.lock.Lock()
	t.bytesSent += int64(n)
	t.lock.Unlock()

	if err == nil || err == io.EOF {
		t.handleThrottle()
		return n, err
	}
	t.TransferError(err)
	return n, err
}
// WriteAt writes len(p) bytes to the uploaded file starting at byte offset
// off and adds the number of bytes written to the bytes received counter.
// The data goes to the pipe writer if one is set, otherwise to the file.
// An offset lower than minWriteOffset is rejected without writing anything.
// If maxWriteSize is greater than zero and the received bytes exceed it,
// errQuotaExceeded is returned. Every error is recorded using TransferError;
// on success the upload bandwidth throttling is applied before returning.
func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
	t.lastActivity = time.Now()

	if off < t.minWriteOffset {
		err = fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
		t.TransferError(err)
		return 0, err
	}

	if t.writerAt != nil {
		n, err = t.writerAt.WriteAt(p, off)
	} else {
		n, err = t.file.WriteAt(p, off)
	}

	t.lock.Lock()
	t.bytesReceived += int64(n)
	// the size limit is only checked if the write itself succeeded
	overLimit := err == nil && t.maxWriteSize > 0 && t.bytesReceived > t.maxWriteSize
	t.lock.Unlock()

	if overLimit {
		err = errQuotaExceeded
	}
	if err != nil {
		t.TransferError(err)
		return n, err
	}
	t.handleThrottle()
	return n, err
}
// Close is called when the transfer is completed.
// It closes the underlying reader, writer or file, calls metrics.TransferCompleted,
// logs the transfer info, passes a notification (including the transfer error, if any)
// to executeAction and updates the user quota (for uploads).
// An upload written to a temporary file (file name different from path) is renamed
// to path if there is no transfer error or if the upload mode is
// uploadModeAtomicWithResume, otherwise we try to delete the temporary file.
// If the transfer failed with errQuotaExceeded we try to delete the partial local file.
// Calling Close on an already closed transfer returns errTransferClosed.
// The returned error is the result of the last close/rename/remove step or,
// if that is nil, the transfer error.
func (t *Transfer) Close() error {
t.lock.Lock()
defer t.lock.Unlock()
if t.isFinished {
return errTransferClosed
}
err := t.closeIO()
// deferred after the unlock above, so it runs before the lock is released
defer removeTransfer(t) //nolint:errcheck
t.isFinished = true
// number of files to add to the quota: 1 for a new file, decremented below
// if the uploaded file is deleted
numFiles := 0
if t.isNewFile {
numFiles = 1
}
metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
// NOTE(review): in the branches below err is reassigned, so an error returned
// by closeIO is lost if the following remove/rename succeeds - confirm this is intended
if t.transferError == errQuotaExceeded && t.file != nil {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = os.Remove(t.file.Name())
if err == nil {
// the file is gone: reset the counters used for the notification and the quota update
numFiles--
t.bytesReceived = 0
t.minWriteOffset = 0
}
logger.Warn(logSender, t.connectionID, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
t.file.Name(), err)
} else if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
// the upload was written to a file different from the final path
if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
// success, or the partial file must be kept at the final path: rename it
err = os.Rename(t.file.Name(), t.path)
logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
t.file.Name(), t.path, err)
} else {
// failed upload: remove the temporary file
err = os.Remove(t.file.Name())
logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
"deletion error: %v", t.transferError, t.file.Name(), err)
if err == nil {
// the file is gone: reset the counters used for the notification and the quota update
numFiles--
t.bytesReceived = 0
t.minWriteOffset = 0
}
}
}
// elapsed time in milliseconds
elapsed := time.Since(t.start).Nanoseconds() / 1000000
if t.transferType == transferDownload {
logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
} else {
logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
// the notified size also includes minWriteOffset
go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
t.transferError))
}
if t.transferError != nil {
logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
// report the transfer error only if the steps above did not fail
if err == nil {
err = t.transferError
}
}
t.updateQuota(numFiles)
return err
}
// closeIO closes the data source/destination of the transfer: the pipe
// writer if set, else the pipe reader if set, else the file.
// An error closing the pipe writer is also stored as the transfer error.
// It returns the close error, if any.
func (t *Transfer) closeIO() error {
	switch {
	case t.writerAt != nil:
		closeErr := t.writerAt.Close()
		if closeErr != nil {
			t.transferError = closeErr
		}
		return closeErr
	case t.readerAt != nil:
		return t.readerAt.Close()
	default:
		return t.file.Close()
	}
}
// updateQuota adds numFiles and the difference between the received bytes
// and the initial size to the quota usage after an upload.
// If the parent directory of the request path is inside a virtual folder,
// the folder quota is updated and the user quota is updated only if the
// folder is included in it; otherwise only the user quota is updated.
// It returns true if a quota update was requested, false if there was
// nothing to update.
func (t *Transfer) updateQuota(numFiles int) bool {
	// S3 uploads are atomic, if there is an error nothing is uploaded
	if t.file == nil && t.transferError != nil {
		return false
	}
	if t.transferType != transferUpload {
		return false
	}
	if numFiles == 0 && t.bytesReceived <= 0 {
		// no new files and no new bytes
		return false
	}

	sizeDiff := t.bytesReceived - t.initialSize
	vfolder, err := t.user.GetVirtualFolderForPath(path.Dir(t.requestPath))
	if err != nil {
		// not inside a virtual folder
		dataprovider.UpdateUserQuota(t.user, numFiles, sizeDiff, false) //nolint:errcheck
		return true
	}
	dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, sizeDiff, false) //nolint:errcheck
	if vfolder.IsIncludedInUserQuota() {
		dataprovider.UpdateUserQuota(t.user, numFiles, sizeDiff, false) //nolint:errcheck
	}
	return true
}
// handleThrottle enforces the user's bandwidth limit for the transfer
// direction: the download limit and the bytes sent for downloads, the upload
// limit and the bytes received otherwise. A limit of zero or less disables
// throttling. If the transfer is running faster than allowed it sleeps for
// the time needed to bring the average rate, computed from the transfer
// start, back to the limit.
// NOTE(review): the byte counters are read here without holding the lock.
func (t *Transfer) handleThrottle() {
	var limit, transferred int64
	if t.transferType == transferDownload {
		limit, transferred = t.user.DownloadBandwidth, t.bytesSent
	} else {
		limit, transferred = t.user.UploadBandwidth, t.bytesReceived
	}
	if limit <= 0 {
		return
	}
	// milliseconds really elapsed since the transfer start
	elapsedMs := int64(time.Since(t.start) / time.Millisecond)
	// milliseconds the transfer should have taken at the wanted rate: the
	// bytes are converted to kilobytes, divided by the limit and the result
	// is expressed in milliseconds
	expectedMs := 1000 * (transferred / 1000) / limit
	if expectedMs > elapsedMs {
		time.Sleep(time.Duration(expectedMs-elapsedMs) * time.Millisecond)
	}
}
// copyFromReaderToWriter is used for ssh commands.
// It copies from src to dst until src returns io.EOF, which is not reported
// as an error, or until a read or write error happens. A write that accepts
// fewer bytes than requested is reported as io.ErrShortWrite.
// The copied bytes are stored as bytes sent for downloads and as bytes
// received otherwise, and the bandwidth throttling is applied after each chunk.
// If maxWriteSize is negative nothing is copied and errQuotaExceeded is
// returned; if it is greater than zero the copy stops with errQuotaExceeded as
// soon as more bytes than that have been written.
// The final error, nil included, is stored as the transfer error and the
// transfer is reported to the metrics if any byte was transferred or an
// error happened. It returns the number of bytes written and the error.
func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, error) {
	if t.maxWriteSize < 0 {
		return 0, errQuotaExceeded
	}

	var (
		total   int64
		copyErr error
	)
	chunk := make([]byte, 32768)

copyLoop:
	for {
		t.lastActivity = time.Now()
		nRead, readErr := src.Read(chunk)
		if nRead > 0 {
			nWritten, writeErr := dst.Write(chunk[:nRead])
			if nWritten > 0 {
				total += int64(nWritten)
				if t.transferType == transferDownload {
					t.bytesSent = total
				} else {
					t.bytesReceived = total
				}
				// the size limit takes precedence over a write error
				if t.maxWriteSize > 0 && total > t.maxWriteSize {
					copyErr = errQuotaExceeded
					break copyLoop
				}
			}
			switch {
			case writeErr != nil:
				copyErr = writeErr
				break copyLoop
			case nRead != nWritten:
				copyErr = io.ErrShortWrite
				break copyLoop
			}
		}
		if readErr != nil {
			// io.EOF just means that the source is exhausted
			if readErr != io.EOF {
				copyErr = readErr
			}
			break copyLoop
		}
		t.handleThrottle()
	}

	t.transferError = copyErr
	if t.bytesSent > 0 || t.bytesReceived > 0 || copyErr != nil {
		metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
	}
	return total, copyErr
}