diff --git a/httpd/api_http_user.go b/httpd/api_http_user.go index de8c042a..b1e54b3b 100644 --- a/httpd/api_http_user.go +++ b/httpd/api_http_user.go @@ -181,11 +181,15 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) { common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) + t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) + r.Body = t err = r.ParseMultipartForm(maxMultipartMem) if err != nil { + connection.RemoveTransfer(t) sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest) return } + connection.RemoveTransfer(t) defer r.MultipartForm.RemoveAll() //nolint:errcheck parentDir := util.CleanPath(r.URL.Query().Get("path")) @@ -201,6 +205,7 @@ func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connectio files []*multipart.FileHeader, ) int { uploaded := 0 + connection.User.UploadBandwidth = 0 for _, f := range files { file, err := f.Open() if err != nil { diff --git a/httpd/api_shares.go b/httpd/api_shares.go index fd7287ac..154936c9 100644 --- a/httpd/api_shares.go +++ b/httpd/api_shares.go @@ -158,11 +158,15 @@ func uploadToShare(w http.ResponseWriter, r *http.Request) { common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) + t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) + r.Body = t err = r.ParseMultipartForm(maxMultipartMem) if err != nil { + connection.RemoveTransfer(t) sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest) return } + connection.RemoveTransfer(t) defer r.MultipartForm.RemoveAll() //nolint:errcheck files := r.MultipartForm.File["filenames"] diff --git a/httpd/handler.go b/httpd/handler.go index 375dd071..fd8ff559 100644 --- a/httpd/handler.go +++ b/httpd/handler.go @@ -6,6 +6,8 @@ import ( "os" "path" "strings" + "sync/atomic" + "time" "github.com/drakkan/sftpgo/v2/common" "github.com/drakkan/sftpgo/v2/dataprovider" @@ -214,3 +216,87 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) return newHTTPDFile(baseTransfer, w, nil), nil } + +func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader { + t := &throttledReader{ + bytesRead: 0, + id: conn.GetTransferID(), + limit: limit, + r: r, + abortTransfer: 0, + start: time.Now(), + conn: conn, + } + conn.AddTransfer(t) + return t +} + +type throttledReader struct { + bytesRead int64 + id uint64 + limit int64 + r io.ReadCloser + abortTransfer int32 + start time.Time + conn *Connection +} + +func (t *throttledReader) GetID() uint64 { + return t.id +} + +func (t *throttledReader) GetType() int { + return common.TransferUpload +} + +func (t *throttledReader) GetSize() int64 { + return atomic.LoadInt64(&t.bytesRead) +} + +func (t *throttledReader) GetVirtualPath() string { + return "**reading request body**" +} + +func (t *throttledReader) GetStartTime() time.Time { + return t.start +} + +func (t *throttledReader) SignalClose() { + atomic.StoreInt32(&(t.abortTransfer), 1) +} + +func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) { + return 0, vfs.ErrVfsUnsupported +} + +func (t *throttledReader) GetRealFsPath(fsPath string) string { + return "" +} + +func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Time) bool { + return false +} + +func (t *throttledReader) Read(p []byte) (n int, err error) { + if atomic.LoadInt32(&t.abortTransfer) == 1 { + return 0, errTransferAborted + } + + t.conn.UpdateLastActivity() + n, err = t.r.Read(p) + if t.limit > 0 { + atomic.AddInt64(&t.bytesRead, int64(n)) + trasferredBytes := atomic.LoadInt64(&t.bytesRead) + elapsed := time.Since(t.start).Nanoseconds() / 1000000 + wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit + if wantedElapsed > elapsed { + toSleep := time.Duration(wantedElapsed - elapsed) + time.Sleep(toSleep * time.Millisecond) + } + } + return +} + +func (t *throttledReader) Close() error { + return t.r.Close() +} diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 422d919d..65c2de6b 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -10668,7 +10668,7 @@ func TestClientUserClose(t *testing.T) { req.Header.Add("Content-Type", writer.FormDataContentType()) setBearerForReq(req, webAPIToken) rr := executeRequest(req) - checkResponseCode(t, http.StatusInternalServerError, rr) + checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "transfer aborted") }() // wait for the transfers diff --git a/httpd/internal_test.go b/httpd/internal_test.go index b0769d68..6039d97b 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "html/template" + "io" "net/http" "net/http/httptest" "net/url" @@ -1810,6 +1811,18 @@ func TestGetFileWriterErrors(t *testing.T) { assert.Error(t, err) } +func TestThrottledHandler(t *testing.T) { + tr := &throttledReader{ + r: io.NopCloser(bytes.NewBuffer(nil)), + } + err := tr.Close() + assert.NoError(t, err) + assert.Empty(t, tr.GetRealFsPath("real path")) + assert.False(t, tr.SetTimes("p", time.Now(), time.Now())) + _, err = tr.Truncate("", 0) + assert.ErrorIs(t, err, vfs.ErrVfsUnsupported) +} + func TestHTTPDFile(t *testing.T) { user := dataprovider.User{ BaseUser: sdk.BaseUser{ @@ -1857,6 +1870,9 @@ func TestHTTPDFile(t *testing.T) { assert.Error(t, err) assert.Error(t, httpdFile.ErrTransfer) assert.Equal(t, err, httpdFile.ErrTransfer) + httpdFile.SignalClose() + _, err = httpdFile.Write(nil) + assert.ErrorIs(t, err, errTransferAborted) } func TestChangeUserPwd(t *testing.T) { diff --git a/templates/webclient/files.html b/templates/webclient/files.html index 3d6831b1..f3121d55 100644 --- a/templates/webclient/files.html +++ b/templates/webclient/files.html @@ -85,7 +85,7 @@ - +