diff --git a/go.mod b/go.mod index ef9cbe3f..a0cc7512 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/bmatcuk/doublestar/v4 v4.7.1 github.com/cockroachdb/cockroach-go/v2 v2.3.8 github.com/coreos/go-oidc/v3 v3.11.0 - github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb + github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 github.com/fclairamb/ftpserverlib v0.24.1 github.com/fclairamb/go-log v0.5.0 diff --git a/go.sum b/go.sum index 205c6eaa..561696c0 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,8 @@ github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE= github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e h1:VBpqQeChkGXSV1FXCtvd3BJTyB+DcMgiu7SfkpsGuKw= github.com/drakkan/ftpserverlib v0.0.0-20240603150004-6a8f643fbf2e/go.mod h1:aAwyOAC6IIe+IZeeGD1QjuE3GGDzqW/c5Xtn+Dp0JUM= -github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb h1:067/Uo8cfeY7QC0yzWCr/RImuNcM0rLWAsBUyMks59o= -github.com/drakkan/webdav v0.0.0-20240503091431-218ec83910bb/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= +github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b h1:Y1tLiQ8fnxM5f3wiBjAXsHzHNwiY9BR+mXZA75nZwrs= +github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 h1:/ZshrfQzayqRSBDodmp3rhNCHJCff+utvgBuWRbiqu4= github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001/go.mod h1:kltMsfRMTHSFdMbK66XdS8mfMW77+FZA1fGY1xYMF84= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/internal/common/common.go b/internal/common/common.go index 78c11b75..6435c2af 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -125,6 +125,9 @@ func init() { Connections.clients = clientsMap{ clients: make(map[string]int), } + Connections.transfers = clientsMap{ + clients: make(map[string]int), + } Connections.perUserConns = make(map[string]int) Connections.mapping = make(map[string]int) Connections.sshMapping = make(map[string]int) @@ -908,7 +911,9 @@ func (c *SSHConnection) Close() error { type ActiveConnections struct { // clients contains both authenticated and estabilished connections and the ones waiting // for authentication - clients clientsMap + clients clientsMap + // transfers contains active transfers, total and per-user + transfers clientsMap transfersCheckStatus atomic.Bool sync.RWMutex connections []ActiveConnection @@ -959,6 +964,9 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error { if val := conns.perUserConns[username]; val >= maxSessions { return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) } + if val := conns.transfers.getTotalFrom(username); val >= maxSessions { + return fmt.Errorf("too many open transfers: %d/%d", val, maxSessions) + } } conns.addUserConnection(username) } @@ -1219,6 +1227,35 @@ func (conns *ActiveConnections) GetClientConnections() int32 { return conns.clients.getTotal() } +// GetTotalTransfers returns the total number of active transfers +func (conns *ActiveConnections) GetTotalTransfers() int32 { + return conns.transfers.getTotal() +} + +// IsNewTransferAllowed returns an error if the maximum number of concurrent allowed +// transfers is exceeded +func (conns *ActiveConnections) IsNewTransferAllowed(username string) error { + if isShuttingDown.Load() { + return ErrShuttingDown + } + if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { + return nil + } + if Config.MaxPerHostConnections > 0 { + if transfers := conns.transfers.getTotalFrom(username); transfers >= Config.MaxPerHostConnections { + logger.Info(logSender, "", "active transfers from user %q: %d/%d", username, transfers, Config.MaxPerHostConnections) + return ErrConnectionDenied + } + } + if Config.MaxTotalConnections > 0 { + if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { + logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) + return ErrConnectionDenied + } + } + return nil +} + // IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed // connections is exceeded or a whitelist is defined and the specified ipAddr is not listed // or the service is shutting down @@ -1259,7 +1296,11 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string) } // on a single SFTP connection we could have multiple SFTP channels or commands - // so we check the estabilished connections too + // so we check the estabilished connections and active uploads too + if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { + logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) + return ErrConnectionDenied + } conns.RLock() defer conns.RUnlock() diff --git a/internal/common/common_test.go b/internal/common/common_test.go index 4e856604..c6d801ba 100644 --- a/internal/common/common_test.go +++ b/internal/common/common_test.go @@ -626,11 +626,17 @@ func TestMaxConnections(t *testing.T) { ipAddr := "192.168.7.8" assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP)) + assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername)) Config.MaxTotalConnections = 1 Config.MaxPerHostConnections = perHost assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolHTTP)) + assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername)) + isShuttingDown.Store(true) + assert.ErrorIs(t, Connections.IsNewTransferAllowed(userTestUsername), ErrShuttingDown) + isShuttingDown.Store(false) + c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, @@ -639,6 +645,10 @@ func TestMaxConnections(t *testing.T) { assert.NoError(t, err) assert.Len(t, Connections.GetStats(""), 1) assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH)) + Connections.transfers.add(userTestUsername) + assert.Error(t, Connections.IsNewTransferAllowed(userTestUsername)) + Connections.transfers.remove(userTestUsername) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) res := Connections.Close(fakeConn.GetID(), "") assert.True(t, res) @@ -650,6 +660,9 @@ func TestMaxConnections(t *testing.T) { assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH)) Connections.RemoveClientConnection(ipAddr) assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV)) + Connections.transfers.add(userTestUsername) + assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH)) + Connections.transfers.remove(userTestUsername) Connections.RemoveClientConnection(ipAddr) Config.MaxTotalConnections = oldValue diff --git a/internal/common/connection.go b/internal/common/connection.go index b63ef93a..8d46ed47 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -159,6 +159,8 @@ func (c *BaseConnection) CloseFS() error { // AddTransfer associates a new transfer to this connection func (c *BaseConnection) AddTransfer(t ActiveTransfer) { + Connections.transfers.add(c.User.Username) + c.Lock() defer c.Unlock() @@ -190,6 +192,8 @@ func (c *BaseConnection) AddTransfer(t ActiveTransfer) { // RemoveTransfer removes the specified transfer from the active ones func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) { + Connections.transfers.remove(c.User.Username) + c.Lock() defer c.Unlock() diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index 9ddb3e1a..f82b28a7 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -8130,6 +8130,86 @@ func TestRetentionAPI(t *testing.T) { assert.NoError(t, err) } +func TestPerUserTransferLimits(t *testing.T) { + oldMaxPerHostConns := common.Config.MaxPerHostConnections + + common.Config.MaxPerHostConnections = 2 + + u := getTestUser() + u.UploadBandwidth = 32 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + var wg sync.WaitGroup + numErrors := 0 + for i := 0; i <= 2; i++ { + wg.Add(1) + go func(counter int) { + defer wg.Done() + + time.Sleep(20 * time.Millisecond) + err := writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client) + if err != nil { + numErrors++ + } + }(i) + } + wg.Wait() + + assert.Equal(t, 1, numErrors) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldMaxPerHostConns +} + +func TestMaxSessionsSameConnection(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 32 + u.MaxSessions = 2 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + var wg sync.WaitGroup + numErrors := 0 + for i := 0; i <= 2; i++ { + wg.Add(1) + go func(counter int) { + defer wg.Done() + + time.Sleep(20 * time.Millisecond) + var err error + if counter < 2 { + err = writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client) + } else { + _, _, err = getSftpClient(user) + } + if err != nil { + numErrors++ + } + }(i) + } + + wg.Wait() + assert.Equal(t, 1, numErrors) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestRenameDir(t *testing.T) { u := getTestUser() testDir := "/dir-to-rename" diff --git a/internal/common/transfer_test.go b/internal/common/transfer_test.go index 5324484b..8b41e2ff 100644 --- a/internal/common/transfer_test.go +++ b/internal/common/transfer_test.go @@ -323,6 +323,9 @@ func TestRemovePartialCryptoFile(t *testing.T) { assert.Equal(t, int64(0), size) assert.Equal(t, 1, deletedFiles) assert.NoFileExists(t, testFile) + err = transfer.Close() + assert.Error(t, err) + assert.Len(t, conn.GetTransfers(), 0) } func TestFTPMode(t *testing.T) { @@ -434,6 +437,11 @@ func TestTransferQuota(t *testing.T) { } err = transfer.CheckWrite() assert.True(t, conn.IsQuotaExceededError(err)) + + err = transfer.Close() + assert.NoError(t, err) + assert.Len(t, conn.GetTransfers(), 0) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) } func TestUploadOutsideHomeRenameError(t *testing.T) { diff --git a/internal/common/transferschecker_test.go b/internal/common/transferschecker_test.go index 443e438e..0528de61 100644 --- a/internal/common/transferschecker_test.go +++ b/internal/common/transferschecker_test.go @@ -250,6 +250,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { Connections.Remove(fakeConn5.GetID()) stats := Connections.GetStats("") assert.Len(t, stats, 0) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) @@ -368,11 +369,16 @@ func TestTransferCheckerTransferQuota(t *testing.T) { if assert.Error(t, transfer4.errAbort) { assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error()) } + err = transfer3.Close() + assert.NoError(t, err) + err = transfer4.Close() + assert.NoError(t, err) Connections.Remove(fakeConn3.GetID()) Connections.Remove(fakeConn4.GetID()) stats := Connections.GetStats("") assert.Len(t, stats, 0) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) diff --git a/internal/ftpd/cryptfs_test.go b/internal/ftpd/cryptfs_test.go index c036b59a..23567fff 100644 --- a/internal/ftpd/cryptfs_test.go +++ b/internal/ftpd/cryptfs_test.go @@ -134,6 +134,7 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestBufferedCryptFs(t *testing.T) { @@ -179,6 +180,7 @@ func TestBufferedCryptFs(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestZeroBytesTransfersCryptFs(t *testing.T) { diff --git a/internal/ftpd/ftpd_test.go b/internal/ftpd/ftpd_test.go index 857351f8..18c650fe 100644 --- a/internal/ftpd/ftpd_test.go +++ b/internal/ftpd/ftpd_test.go @@ -37,6 +37,7 @@ import ( ftpserver "github.com/fclairamb/ftpserverlib" "github.com/jlaffaye/ftp" + "github.com/pkg/sftp" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/rs/zerolog" @@ -44,6 +45,7 @@ import ( sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" @@ -671,6 +673,7 @@ func TestBasicFTPHandling(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestHTTPFs(t *testing.T) { @@ -715,6 +718,7 @@ func TestHTTPFs(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestListDirWithWildcards(t *testing.T) { @@ -1735,6 +1739,66 @@ func TestMaxPerHostConnections(t *testing.T) { common.Config.MaxPerHostConnections = oldValue } +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + user := getTestUser() + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, sftpClient, err := getSftpClient(user) + assert.NoError(t, err) + defer conn.Close() + defer sftpClient.Close() + + f1, err := sftpClient.Create("file1") + assert.NoError(t, err) + f2, err := sftpClient.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.Error(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.Error(t, err) + err := client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + func TestRateLimiter(t *testing.T) { oldConfig := config.GetCommonConfig() @@ -3962,6 +4026,7 @@ func TestNestedVirtualFolders(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func checkBasicFTP(client *ftp.ServerConn) error { @@ -4213,6 +4278,30 @@ func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []by return content } +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + func getExitCodeScriptContent(exitCode int) []byte { content := []byte("#!/bin/sh\n\n") content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) diff --git a/internal/ftpd/handler.go b/internal/ftpd/handler.go index 8ed36a6b..b11778f5 100644 --- a/internal/ftpd/handler.go +++ b/internal/ftpd/handler.go @@ -331,6 +331,11 @@ func (c *Connection) GetHandle(name string, flags int, offset int64) (ftpserver. return nil, errCOMBNotSupported } + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying transfer due to count limits") + return nil, c.GetPermissionDeniedError() + } + if flags&os.O_WRONLY != 0 { return c.uploadFile(fs, p, name, flags) } diff --git a/internal/ftpd/internal_test.go b/internal/ftpd/internal_test.go index ee877e7d..ca478387 100644 --- a/internal/ftpd/internal_test.go +++ b/internal/ftpd/internal_test.go @@ -664,6 +664,7 @@ func TestClientVersion(t *testing.T) { common.Connections.Remove(connection.GetID()) } assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestDriverMethodsNotImplemented(t *testing.T) { @@ -918,6 +919,7 @@ func TestTransferErrors(t *testing.T) { pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + tr.Connection.RemoveTransfer(tr) tr = newTransfer(baseTransfer, pipeWriter, nil, 0) err = r.Close() @@ -933,6 +935,7 @@ func TestTransferErrors(t *testing.T) { if assert.Error(t, err) { assert.EqualError(t, err, common.ErrOpUnsupported.Error()) } + tr.Connection.RemoveTransfer(tr) err = os.Remove(testfile) assert.NoError(t, err) } diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go index 8a189a6e..fcd02fb6 100644 --- a/internal/httpd/api_http_user.go +++ b/internal/httpd/api_http_user.go @@ -317,6 +317,13 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) { } defer common.Connections.Remove(connection.GetID()) + if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil { + connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits") + sendAPIResponse(w, r, err, "Denying file write due to transfer count limits", + http.StatusConflict) + return + } + transferQuota := connection.GetTransferQuota() if !transferQuota.HasUploadSpace() { connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go index 612b14b3..a54f8c74 100644 --- a/internal/httpd/api_shares.go +++ b/internal/httpd/api_shares.go @@ -380,6 +380,12 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request) if err != nil { return } + if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil { + connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits") + sendAPIResponse(w, r, err, "Denying file write due to transfer count limits", + http.StatusConflict) + return + } transferQuota := connection.GetTransferQuota() if !transferQuota.HasUploadSpace() { diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go index bc936a8e..9821af69 100644 --- a/internal/httpd/handler.go +++ b/internal/httpd/handler.go @@ -97,6 +97,11 @@ func (c *Connection) ReadDir(name string) (vfs.DirLister, error) { func (c *Connection) getFileReader(name string, offset int64, method string) (io.ReadCloser, error) { c.UpdateLastActivity() + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file read due to transfer count limits") + return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) + } + transferQuota := c.GetTransferQuota() if !transferQuota.HasDownloadSpace() { c.Log(logger.LevelInfo, "denying file read due to quota limits") @@ -188,6 +193,10 @@ func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) { } func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, isNewFile bool, fileSize int64) (io.WriteCloser, error) { + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file write due to transfer count limits") + return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message) + } diskQuota, transferQuota := c.HasSpace(isNewFile, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { c.Log(logger.LevelInfo, "denying file write due to quota limits") diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index f2a1085e..4de82c4a 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -49,6 +49,7 @@ import ( "github.com/lithammer/shortuuid/v4" _ "github.com/mattn/go-sqlite3" "github.com/mhale/smtpd" + "github.com/pkg/sftp" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/rs/xid" @@ -6712,6 +6713,7 @@ func TestCloseActiveConnection(t *testing.T) { _, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) { @@ -6744,6 +6746,7 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) { _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestAdminGenerateRecoveryCodesSaveError(t *testing.T) { @@ -8829,6 +8832,7 @@ func TestLoaddataMode(t *testing.T) { assert.NoError(t, err) // mode 2 will update the user and close the previous connection assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) assert.Equal(t, oldUploadBandwidth, user.UploadBandwidth) @@ -13115,6 +13119,7 @@ func TestWebClientMaxConnections(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxTotalConnections = oldValue } @@ -13409,6 +13414,125 @@ func TestMaxSessions(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeReadWrite, + Paths: []string{"/"}, + Password: defaultPassword, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + fileName := "testfile.txt" + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" "))) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + conn, sftpClient, err := getSftpClient(user) + assert.NoError(t, err) + defer conn.Close() + defer sftpClient.Close() + + f1, err := sftpClient.Create("file1") + assert.NoError(t, err) + f2, err := sftpClient.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "filepre") + assert.NoError(t, err) + _, err = part.Write([]byte("file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + + req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+fileName, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError403Message) + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" "))) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + body = new(bytes.Buffer) + writer = multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file11.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file11 content")) + assert.NoError(t, err) + part2, err := writer.CreateFormFile("filenames", "file22.txt") + assert.NoError(t, err) + _, err = part2.Write([]byte("file22 content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader = bytes.NewReader(body.Bytes()) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.MaxPerHostConnections = oldValue } func TestWebConfigsMock(t *testing.T) { @@ -14954,6 +15078,7 @@ func TestShareMaxSessions(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestShareUploadSingle(t *testing.T) { @@ -19088,6 +19213,7 @@ func TestClientUserClose(t *testing.T) { wg.Wait() assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) @@ -27083,6 +27209,30 @@ func checkResponseCode(t *testing.T, expected int, rr *httptest.ResponseRecorder assert.Equal(t, expected, rr.Code, rr.Body.String()) } +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + func createTestFile(path string, size int64) error { baseDir := filepath.Dir(path) if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { diff --git a/internal/sftpd/handler.go b/internal/sftpd/handler.go index d404ec0f..3ae543c4 100644 --- a/internal/sftpd/handler.go +++ b/internal/sftpd/handler.go @@ -76,6 +76,10 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file read due to transfer count limits") + return nil, c.GetPermissionDeniedError() + } transferQuota := c.GetTransferQuota() if !transferQuota.HasDownloadSpace() { c.Log(logger.LevelInfo, "denying file read due to quota limits") @@ -120,9 +124,14 @@ func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { return c.handleFilewrite(request) } -func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) { +func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) { //nolint:gocyclo c.UpdateLastActivity() + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file write due to transfer count limits") + return nil, c.GetPermissionDeniedError() + } + if ok, _ := c.User.IsFileAllowed(request.Filepath); !ok { c.Log(logger.LevelWarn, "writing file %q is not allowed", request.Filepath) return nil, c.GetPermissionDeniedError() diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index 890f62ba..edb8a10f 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -270,6 +270,7 @@ func TestReadWriteErrors(t *testing.T) { err = os.Remove(testfile) assert.NoError(t, err) assert.Len(t, conn.GetTransfers(), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestUnsupportedListOP(t *testing.T) { @@ -1014,6 +1015,8 @@ func TestSystemCommandErrors(t *testing.T) { transfer.MaxWriteSize = -1 _, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst) assert.True(t, transfer.Connection.IsQuotaExceededError(err)) + err = transfer.Close() + assert.Error(t, err) baseTransfer = common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", "", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{ @@ -1031,9 +1034,13 @@ func TestSystemCommandErrors(t *testing.T) { if assert.Error(t, err) { assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) } + err = transfer.Close() + assert.Error(t, err) err = os.RemoveAll(homeDir) assert.NoError(t, err) + + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestCommandGetFsError(t *testing.T) { @@ -1717,6 +1724,7 @@ func TestSCPUploadFiledata(t *testing.T) { if assert.Error(t, err) { assert.EqualError(t, err, common.ErrTransferClosed.Error()) } + transfer.Connection.RemoveTransfer(transfer) mockSSHChannel = MockChannel{ Buffer: bytes.NewBuffer(buf), @@ -1728,9 +1736,12 @@ func TestSCPUploadFiledata(t *testing.T) { transfer.Connection.AddTransfer(transfer) err = scpCommand.getUploadFileData(2, transfer) assert.ErrorContains(t, err, os.ErrClosed.Error()) + transfer.Connection.RemoveTransfer(transfer) err = os.Remove(testfile) assert.NoError(t, err) + + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestUploadError(t *testing.T) { @@ -2040,6 +2051,7 @@ func TestRecoverer(t *testing.T) { err = scpCmd.handle() assert.EqualError(t, err, common.ErrGenericFailure.Error()) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestListernerAcceptErrors(t *testing.T) { @@ -2170,6 +2182,7 @@ func TestMaxUserSessions(t *testing.T) { } common.Connections.Remove(connection.GetID()) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestCanReadSymlink(t *testing.T) { diff --git a/internal/sftpd/scp.go b/internal/sftpd/scp.go index e62332c7..653587e4 100644 --- a/internal/sftpd/scp.go +++ b/internal/sftpd/scp.go @@ -227,6 +227,12 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err } func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { + if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil { + err := fmt.Errorf("denying file write due to transfer count limits") + c.connection.Log(logger.LevelInfo, "denying file write due to transfer count limits") + c.sendErrorMessage(nil, err) + return err + } diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { err := fmt.Errorf("denying file write due to quota limits") @@ -501,6 +507,13 @@ func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.Fi func (c *scpCommand) handleDownload(filePath string) error { c.connection.UpdateLastActivity() + + if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil { + err := fmt.Errorf("denying file read due to transfer count limits") + c.connection.Log(logger.LevelInfo, "denying file read due to transfer count limits") + c.sendErrorMessage(nil, err) + return err + } transferQuota := c.connection.GetTransferQuota() if !transferQuota.HasDownloadSpace() { c.connection.Log(logger.LevelInfo, "denying file read due to quota limits") diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 0f5666d9..3ddf3fa7 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -1202,6 +1202,7 @@ func TestConcurrency(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) @@ -4391,6 +4392,76 @@ func TestMaxPerHostConnections(t *testing.T) { common.Config.MaxPerHostConnections = oldValue } +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + usePubKey := true + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicSFTP(client)) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + + f1, err := client.Create("file1") + assert.NoError(t, err) + f2, err := client.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) + + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) + + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + } + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.MaxPerHostConnections = oldValue +} + func TestMaxSessions(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) @@ -4940,6 +5011,7 @@ func TestBandwidthAndConnections(t *testing.T) { assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 10*time.Second, 200*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) err = os.Remove(localDownloadPath) @@ -9859,6 +9931,62 @@ func TestBasicGitCommands(t *testing.T) { assert.NoError(t, err) } +func TestSSHCommandMaxTransfers(t *testing.T) { + if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows { + t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test") + } + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + repoName := "testrepo" //nolint:goconst + clonePath := filepath.Join(homeBasePath, repoName) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(homeBasePath, repoName)) + assert.NoError(t, err) + out, err := initGitRepo(filepath.Join(user.HomeDir, repoName)) + assert.NoError(t, err, "unexpected error, out: %v", string(out)) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + f1, err := client.Create("file1") + assert.NoError(t, err) + f2, err := client.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + _, err = cloneGitRepo(homeBasePath, "/"+repoName, user.Username) + assert.Error(t, err) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(clonePath) + assert.NoError(t, err) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.MaxPerHostConnections = oldValue +} + func TestGitIncludedVirtualFolders(t *testing.T) { if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows { t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test") @@ -11104,6 +11232,7 @@ func TestSCPErrors(t *testing.T) { err = cmd.Process.Kill() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 2*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false) go func() { err := cmd.Run() @@ -11116,6 +11245,7 @@ func TestSCPErrors(t *testing.T) { err = cmd.Process.Kill() assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 2*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) os.Remove(localPath) diff --git a/internal/sftpd/ssh_cmd.go b/internal/sftpd/ssh_cmd.go index fa782f0e..604a0bb0 100644 --- a/internal/sftpd/ssh_cmd.go +++ b/internal/sftpd/ssh_cmd.go @@ -246,11 +246,15 @@ func (c *sshCommand) handleHashCommands() error { return nil } -func (c *sshCommand) executeSystemCommand(command systemCommand) error { +func (c *sshCommand) executeSystemCommand(command systemCommand) error { //nolint:gocyclo sshDestPath := c.getDestPath() if !c.isLocalPath(sshDestPath) { return c.sendErrorResponse(errUnsupportedConfig) } + if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil { + err := fmt.Errorf("denying command due to transfer count limits") + return c.sendErrorResponse(err) + } diskQuota, transferQuota := c.connection.HasSpace(true, false, command.quotaCheckPath) if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() || !transferQuota.HasDownloadSpace() { return c.sendErrorResponse(common.ErrQuotaExceeded) diff --git a/internal/webdavd/handler.go b/internal/webdavd/handler.go index 829c1173..54f42a18 100644 --- a/internal/webdavd/handler.go +++ b/internal/webdavd/handler.go @@ -145,6 +145,11 @@ func (c *Connection) RemoveAll(_ context.Context, name string) error { func (c *Connection) OpenFile(_ context.Context, name string, flag int, _ os.FileMode) (webdav.File, error) { c.UpdateLastActivity() + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying transfer due to count limits") + return nil, c.GetPermissionDeniedError() + } + name = util.CleanPath(name) fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { diff --git a/internal/webdavd/internal_test.go b/internal/webdavd/internal_test.go index 94e60313..4b676c69 100644 --- a/internal/webdavd/internal_test.go +++ b/internal/webdavd/internal_test.go @@ -760,6 +760,8 @@ func TestContentType(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "application/sftpgo", ctype) } + err = davFile.Close() + assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown2", common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) @@ -814,6 +816,8 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) err = w.Close() assert.NoError(t, err) + err = davFile.BaseTransfer.Close() + assert.Error(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) @@ -822,6 +826,8 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.True(t, fs.IsNotExist(err)) _, err = davFile.Stat() assert.True(t, fs.IsNotExist(err)) + err = davFile.Close() + assert.Error(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) @@ -844,6 +850,8 @@ func TestTransferReadWriteErrors(t *testing.T) { if assert.NoError(t, err) { assert.Equal(t, int64(0), info.Size()) } + err = davFile.Close() + assert.Error(t, err) r, w, err = pipeat.Pipe() assert.NoError(t, err) @@ -987,8 +995,11 @@ func TestTransferSeek(t *testing.T) { res, err = davFile.Seek(2, io.SeekEnd) assert.True(t, fs.IsNotExist(err)) assert.Equal(t, int64(0), res) + err = davFile.Close() + assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(testFilePath) assert.NoError(t, err) diff --git a/internal/webdavd/webdavd_test.go b/internal/webdavd/webdavd_test.go index 6fbdb0a9..fe58b9dd 100644 --- a/internal/webdavd/webdavd_test.go +++ b/internal/webdavd/webdavd_test.go @@ -38,11 +38,13 @@ import ( "time" "github.com/minio/sio" + "github.com/pkg/sftp" "github.com/rs/zerolog" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" "github.com/stretchr/testify/assert" "github.com/studio-b12/gowebdav" + "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/config" @@ -637,6 +639,7 @@ func TestBasicHandling(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) status := webdavd.GetStatus() assert.True(t, status.IsActive) } @@ -721,6 +724,7 @@ func TestBasicHandlingCryptFs(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestBufferedUser(t *testing.T) { @@ -1010,6 +1014,8 @@ func TestRenameWithLock(t *testing.T) { err = resp.Body.Close() assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) @@ -1077,6 +1083,7 @@ func TestPropPatch(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestLoginInvalidPwd(t *testing.T) { @@ -1520,6 +1527,7 @@ func TestPreDownloadHook(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} common.Config.Actions.Hook = preDownloadPath @@ -1570,6 +1578,7 @@ func TestPreUploadHook(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.Actions.ExecuteOn = oldExecuteOn common.Config.Actions.Hook = oldHook @@ -1633,6 +1642,7 @@ func TestMaxConnections(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxTotalConnections = oldValue } @@ -1665,6 +1675,61 @@ func TestMaxPerHostConnections(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + + conn, sftpClient, err := getSftpClient(user) + assert.NoError(t, err) + defer conn.Close() + defer sftpClient.Close() + + f1, err := sftpClient.Create("file1") + assert.NoError(t, err) + f2, err := sftpClient.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) common.Config.MaxPerHostConnections = oldValue } @@ -1712,6 +1777,7 @@ func TestMaxSessions(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func TestLoginWithIPilters(t *testing.T) { @@ -2171,6 +2237,7 @@ func TestClientClose(t *testing.T) { wg.Wait() assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) err = os.Remove(localDownloadPath) assert.NoError(t, err) @@ -3276,6 +3343,7 @@ func TestNestedVirtualFolders(t *testing.T) { assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) } func checkBasicFunc(client *gowebdav.Client) error { @@ -3472,6 +3540,30 @@ func getTestUserWithCryptFs() dataprovider.User { return user } +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + func getEncryptedFileSize(size int64) (int64, error) { encSize, err := sio.EncryptedSize(uint64(size)) return int64(encSize) + 33, err