From 002a06629e588dd13fe41e542aaef37482f6dcb6 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 14 Apr 2022 19:07:41 +0200 Subject: [PATCH] refactoring of user session counters Fixes #792 Signed-off-by: Nicola Murino --- common/common.go | 55 ++++++++-- common/common_test.go | 70 +++++++++--- common/connection.go | 5 + common/transferschecker_test.go | 27 +++-- ftpd/internal_test.go | 3 +- ftpd/server.go | 10 +- go.mod | 15 +-- go.sum | 29 ++--- httpd/api_http_user.go | 70 +++++------- httpd/api_shares.go | 25 ++++- httpd/api_utils.go | 4 +- httpd/httpd_test.go | 185 +++++++++++++++++++++++++++++--- httpd/middleware.go | 2 +- httpd/oidc.go | 2 +- httpd/server.go | 8 +- httpd/webclient.go | 38 +++++-- sftpd/cryptfs_test.go | 4 +- sftpd/handler.go | 6 +- sftpd/internal_test.go | 43 ++++++++ sftpd/scp.go | 5 +- sftpd/server.go | 10 +- sftpd/sftpd_test.go | 52 +++++---- sftpd/ssh_cmd.go | 7 +- sftpd/subsystem.go | 9 +- vfs/cryptfs.go | 12 +-- vfs/osfs.go | 13 +-- webdavd/server.go | 26 +++-- webdavd/webdavd_test.go | 6 +- 28 files changed, 542 insertions(+), 199 deletions(-) diff --git a/common/common.go b/common/common.go index 55fda7d7..40e54655 100644 --- a/common/common.go +++ b/common/common.go @@ -98,6 +98,7 @@ func init() { Connections.clients = clientsMap{ clients: make(map[string]int), } + Connections.perUserConns = make(map[string]int) } // errors definitions @@ -345,6 +346,7 @@ type ActiveTransfer interface { type ActiveConnection interface { GetID() string GetUsername() string + GetMaxSessions() int GetLocalAddress() string GetRemoteAddress() string GetClientVersion() string @@ -733,6 +735,29 @@ type ActiveConnections struct { sync.RWMutex connections []ActiveConnection sshConnections []*SSHConnection + perUserConns map[string]int +} + +// internal method, must be called within a locked block +func (conns *ActiveConnections) addUserConnection(username string) { + if username == "" { + return + } + conns.perUserConns[username]++ +} + +// internal method, must be called within a locked block +func (conns *ActiveConnections) removeUserConnection(username string) { + if username == "" { + return + } + if val, ok := conns.perUserConns[username]; ok { + conns.perUserConns[username]-- + if val > 1 { + return + } + delete(conns.perUserConns, username) + } } // GetActiveSessions returns the number of active sessions for the given username. @@ -741,24 +766,27 @@ func (conns *ActiveConnections) GetActiveSessions(username string) int { conns.RLock() defer conns.RUnlock() - numSessions := 0 - for _, c := range conns.connections { - if c.GetUsername() == username { - numSessions++ - } - } - return numSessions + return conns.perUserConns[username] } // Add adds a new connection to the active ones -func (conns *ActiveConnections) Add(c ActiveConnection) { +func (conns *ActiveConnections) Add(c ActiveConnection) error { conns.Lock() defer conns.Unlock() + if username := c.GetUsername(); username != "" { + if maxSessions := c.GetMaxSessions(); maxSessions > 0 { + if val := conns.perUserConns[username]; val >= maxSessions { + return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) + } + } + conns.addUserConnection(username) + } conns.connections = append(conns.connections, c) metric.UpdateActiveConnectionsSize(len(conns.connections)) logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v", c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections)) + return nil } // Swap replaces an existing connection with the given one. @@ -771,6 +799,16 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error { for idx, conn := range conns.connections { if conn.GetID() == c.GetID() { + conns.removeUserConnection(conn.GetUsername()) + if username := c.GetUsername(); username != "" { + if maxSessions := c.GetMaxSessions(); maxSessions > 0 { + if val := conns.perUserConns[username]; val >= maxSessions { + conns.addUserConnection(conn.GetUsername()) + return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) + } + } + conns.addUserConnection(username) + } err := conn.CloseFS() conns.connections[idx] = c logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err) @@ -793,6 +831,7 @@ func (conns *ActiveConnections) Remove(connectionID string) { conns.connections[idx] = conns.connections[lastIdx] conns.connections[lastIdx] = nil conns.connections = conns.connections[:lastIdx] + conns.removeUserConnection(conn.GetUsername()) metric.UpdateActiveConnectionsSize(lastIdx) logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v", conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx) diff --git a/common/common_test.go b/common/common_test.go index 4f8c12a5..8d975fc9 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -370,6 +370,29 @@ func TestWhitelist(t *testing.T) { Config = configCopy } +func TestUserMaxSessions(t *testing.T) { + c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + MaxSessions: 1, + }, + }) + fakeConn := &fakeConnection{ + BaseConnection: c, + } + err := Connections.Add(fakeConn) + assert.NoError(t, err) + err = Connections.Add(fakeConn) + assert.Error(t, err) + err = Connections.Swap(fakeConn) + assert.NoError(t, err) + Connections.Remove(fakeConn.GetID()) + Connections.Lock() + Connections.removeUserConnection(userTestUsername) + Connections.Unlock() + assert.Len(t, Connections.GetStats(), 0) +} + func TestMaxConnections(t *testing.T) { oldValue := Config.MaxTotalConnections perHost := Config.MaxPerHostConnections @@ -387,7 +410,8 @@ func TestMaxConnections(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - Connections.Add(fakeConn) + err := Connections.Add(fakeConn) + assert.NoError(t, err) assert.Len(t, Connections.GetStats(), 1) assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) @@ -466,14 +490,16 @@ func TestIdleConnections(t *testing.T) { sshConn1.lastActivity = c.lastActivity sshConn2.lastActivity = c.lastActivity Connections.AddSSHConnection(sshConn1) - Connections.Add(fakeConn) + err = Connections.Add(fakeConn) + assert.NoError(t, err) assert.Equal(t, Connections.GetActiveSessions(username), 1) c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user) fakeConn = &fakeConnection{ BaseConnection: c, } Connections.AddSSHConnection(sshConn2) - Connections.Add(fakeConn) + err = Connections.Add(fakeConn) + assert.NoError(t, err) assert.Equal(t, Connections.GetActiveSessions(username), 2) cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{}) @@ -481,7 +507,8 @@ func TestIdleConnections(t *testing.T) { fakeConn = &fakeConnection{ BaseConnection: cFTP, } - Connections.Add(fakeConn) + err = Connections.Add(fakeConn) + assert.NoError(t, err) assert.Equal(t, Connections.GetActiveSessions(username), 2) assert.Len(t, Connections.GetStats(), 3) Connections.RLock() @@ -521,7 +548,8 @@ func TestCloseConnection(t *testing.T) { BaseConnection: c, } assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1")) - Connections.Add(fakeConn) + err := Connections.Add(fakeConn) + assert.NoError(t, err) assert.Len(t, Connections.GetStats(), 1) res := Connections.Close(fakeConn.GetID()) assert.True(t, res) @@ -536,19 +564,34 @@ func TestSwapConnection(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - Connections.Add(fakeConn) + err := Connections.Add(fakeConn) + assert.NoError(t, err) if assert.Len(t, Connections.GetStats(), 1) { assert.Equal(t, "", Connections.GetStats()[0].Username) } c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{ BaseUser: sdk.BaseUser{ - Username: userTestUsername, + Username: userTestUsername, + MaxSessions: 1, }, }) fakeConn = &fakeConnection{ BaseConnection: c, } - err := Connections.Swap(fakeConn) + c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + }, + }) + fakeConn1 := &fakeConnection{ + BaseConnection: c1, + } + err = Connections.Add(fakeConn1) + assert.NoError(t, err) + err = Connections.Swap(fakeConn) + assert.Error(t, err) + Connections.Remove(fakeConn1.ID) + err = Connections.Swap(fakeConn) assert.NoError(t, err) if assert.Len(t, Connections.GetStats(), 1) { assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username) @@ -600,9 +643,12 @@ func TestConnectionStatus(t *testing.T) { command: "PROPFIND", } t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - Connections.Add(fakeConn1) - Connections.Add(fakeConn2) - Connections.Add(fakeConn3) + err := Connections.Add(fakeConn1) + assert.NoError(t, err) + err = Connections.Add(fakeConn2) + assert.NoError(t, err) + err = Connections.Add(fakeConn3) + assert.NoError(t, err) stats := Connections.GetStats() assert.Len(t, stats, 3) @@ -628,7 +674,7 @@ func TestConnectionStatus(t *testing.T) { } } - err := t1.Close() + err = t1.Close() assert.NoError(t, err) err = t2.Close() assert.NoError(t, err) diff --git a/common/connection.go b/common/connection.go index 494087e6..4d9d23a4 100644 --- a/common/connection.go +++ b/common/connection.go @@ -80,6 +80,11 @@ func (c *BaseConnection) GetUsername() string { return c.User.Username } +// GetMaxSessions returns the maximum number of concurrent sessions allowed +func (c *BaseConnection) GetMaxSessions() int { + return c.User.MaxSessions +} + // GetProtocol returns the protocol for the connection func (c *BaseConnection) GetProtocol() string { return c.protocol diff --git a/common/transferschecker_test.go b/common/transferschecker_test.go index 50846ea4..63bb7731 100644 --- a/common/transferschecker_test.go +++ b/common/transferschecker_test.go @@ -62,7 +62,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived = 150 - Connections.Add(fakeConn1) + err = Connections.Add(fakeConn1) + assert.NoError(t, err) // the transferschecker will do nothing if there is only one ongoing transfer Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) @@ -76,7 +77,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) transfer1.BytesReceived = 50 transfer2.BytesReceived = 60 - Connections.Add(fakeConn2) + err = Connections.Add(fakeConn2) + assert.NoError(t, err) connID3 := xid.New().String() conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) @@ -86,7 +88,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) transfer3.BytesReceived = 60 // this value will be ignored, this is a download - Connections.Add(fakeConn3) + err = Connections.Add(fakeConn3) + assert.NoError(t, err) // the transfers are not overquota Connections.checkTransfers() @@ -146,7 +149,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"), filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0, 100, 0, true, fsFolder, dataprovider.TransferQuota{}) - Connections.Add(fakeConn4) + err = Connections.Add(fakeConn4) + assert.NoError(t, err) connID5 := xid.New().String() conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user) fakeConn5 := &fakeConnection{ @@ -156,7 +160,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0, 100, 0, true, fsFolder, dataprovider.TransferQuota{}) - Connections.Add(fakeConn5) + err = Connections.Add(fakeConn5) + assert.NoError(t, err) transfer4.BytesReceived = 50 transfer5.BytesReceived = 40 Connections.checkTransfers() @@ -245,7 +250,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) { transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) transfer1.BytesReceived = 150 - Connections.Add(fakeConn1) + err = Connections.Add(fakeConn1) + assert.NoError(t, err) // the transferschecker will do nothing if there is only one ongoing transfer Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) @@ -258,7 +264,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) { transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) transfer2.BytesReceived = 150 - Connections.Add(fakeConn2) + err = Connections.Add(fakeConn2) + assert.NoError(t, err) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) @@ -294,7 +301,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) { transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) transfer3.BytesSent = 150 - Connections.Add(fakeConn3) + err = Connections.Add(fakeConn3) + assert.NoError(t, err) connID4 := xid.New().String() conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) @@ -304,7 +312,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) { transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) transfer4.BytesSent = 150 - Connections.Add(fakeConn4) + err = Connections.Add(fakeConn4) + assert.NoError(t, err) Connections.checkTransfers() assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer4.errAbort) diff --git a/ftpd/internal_test.go b/ftpd/internal_test.go index 6e43565d..b428ac18 100644 --- a/ftpd/internal_test.go +++ b/ftpd/internal_test.go @@ -593,7 +593,8 @@ func TestClientVersion(t *testing.T) { BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), clientContext: mockCC, } - common.Connections.Add(connection) + err := common.Connections.Add(connection) + assert.NoError(t, err) stats := common.Connections.GetStats() if assert.Len(t, stats, 1) { assert.Equal(t, "mock version", stats[0].ClientVersion) diff --git a/ftpd/server.go b/ftpd/server.go index 4420a87a..ca4aa4b7 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -168,8 +168,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { cc.RemoteAddr().String(), user), clientContext: cc, } - common.Connections.Add(connection) - return s.initialMsg, nil + err = common.Connections.Add(connection) + return s.initialMsg, err } // ClientDisconnected is called when the user disconnects, even if he never authenticated @@ -367,9 +367,9 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext } err = common.Connections.Swap(connection) if err != nil { - err = user.CloseFs() - logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err) - return nil, common.ErrInternalFailure + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose) + return nil, err } return connection, nil } diff --git a/go.mod b/go.mod index 0b2dd090..1138ff88 100644 --- a/go.mod +++ b/go.mod @@ -12,9 +12,9 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.15.3 github.com/aws/aws-sdk-go-v2/credentials v1.11.2 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3 - github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4 + github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4 github.com/aws/aws-sdk-go-v2/service/sts v1.16.3 github.com/cockroachdb/cockroach-go/v2 v2.2.8 @@ -32,10 +32,10 @@ require ( github.com/grandcat/zeroconf v1.0.0 github.com/hashicorp/go-hclog v1.2.0 github.com/hashicorp/go-plugin v1.4.3 - github.com/hashicorp/go-retryablehttp v0.7.0 + github.com/hashicorp/go-retryablehttp v0.7.1 github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126 github.com/klauspost/compress v1.15.1 - github.com/lestrrat-go/jwx v1.2.22 + github.com/lestrrat-go/jwx v1.2.23 github.com/lib/pq v1.10.5 github.com/lithammer/shortuuid/v3 v3.0.7 github.com/mattn/go-sqlite3 v1.14.12 @@ -54,7 +54,7 @@ require ( github.com/shirou/gopsutil/v3 v3.22.3 github.com/spf13/afero v1.8.2 github.com/spf13/cobra v1.4.0 - github.com/spf13/viper v1.10.1 + github.com/spf13/viper v1.11.0 github.com/stretchr/testify v1.7.1 github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62 github.com/unrolled/secure v1.10.0 @@ -67,7 +67,7 @@ require ( golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/net v0.0.0-20220412020605-290c469a71a5 golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 - golang.org/x/sys v0.0.0-20220412071739-889880a91fd5 + golang.org/x/sys v0.0.0-20220412211240-33da011f77ad golang.org/x/time v0.0.0-20220411224347-583f2d630306 google.golang.org/api v0.74.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 @@ -130,6 +130,7 @@ require ( github.com/mitchellh/mapstructure v1.4.3 // indirect github.com/oklog/run v1.1.0 // indirect github.com/pelletier/go-toml v1.9.4 // indirect + github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20220216144756-c35f1ee13d7c // indirect @@ -151,7 +152,7 @@ require ( golang.org/x/tools v0.1.10 // indirect golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac // indirect + google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 // indirect google.golang.org/grpc v1.45.0 // indirect google.golang.org/protobuf v1.28.0 // indirect gopkg.in/ini.v1 v1.66.4 // indirect diff --git a/go.sum b/go.sum index 13b34f6e..df2f7277 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.11.2/go.mod h1:j8YsY9TXTm31k4eFhspiQ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 h1:LWPg5zjHV9oz/myQr4wMs0gi4CjnDN/ILmyZUFYXZsU= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4 h1:iqcMQBj/B3FPxVb5SGNHC8XAh64hmaWUC8piZArBE7U= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4/go.mod h1:s79ZPBpDzcR1BCuAhGCF1rgmd/QmLueKCvdkmX4SDgg= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5 h1:lPo/NX1o4vkk2C7mHmB2FCf9Qp7KZNHrlzHxdP/yugw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5/go.mod h1:JNo9mEKrjnmDBc19z65TZmj1xG9PQHu2GOlApYk31DU= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 h1:onz/VaaxZ7Z4V+WIN9Txly9XLTmoOh1oJ8XcAC3pako= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 h1:9stUQR/u2KXU6HkFJYlqnZEjBnbgrVbG6I5HN09xZh0= @@ -164,8 +164,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYY github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3 h1:xqXHk4UDW7ii4MRciyLpY87yuZds0iymmgHt3h35xTE= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3/go.mod h1:HT0cm2+NUCF33MdXjck554HC6VRgQ4q6JIlSqlYZ18Y= github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4 h1:frOI/v6KWuKGlKUA5gheRw01EDpxcCxTalFQkCOZXAo= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5 h1:A3PuAUlh1u47WHcM68CDaG9ZWjK7ewePjDp+0dY9yv4= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4 h1:EmIEXOjAdXtxa2OGM1VAajZV/i06Q8qd4kBpJd9/p1k= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o= github.com/aws/aws-sdk-go-v2/service/sns v1.17.4/go.mod h1:kElt+uCcXxcqFyc+bQqZPFD9DME/eC6oHBXvFzQ9Bcw= @@ -427,8 +427,8 @@ github.com/hashicorp/go-hclog v1.2.0 h1:La19f8d7WIlm4ogzNHB0JGqs5AUDAZ2UfCY4sJXc github.com/hashicorp/go-hclog v1.2.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM= github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ= -github.com/hashicorp/go-retryablehttp v0.7.0 h1:eu1EI/mbirUgP5C8hVsTNaGZreBDlYiwC1FZWkvQPQ4= -github.com/hashicorp/go-retryablehttp v0.7.0/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= +github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ= +github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -549,8 +549,8 @@ github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbq github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU= -github.com/lestrrat-go/jwx v1.2.22 h1:bCxMokwHNuJHVxgANP4OBddXGtQ9Oy+6cqp4O2rW7DU= -github.com/lestrrat-go/jwx v1.2.22/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8= +github.com/lestrrat-go/jwx v1.2.23 h1:8oP5fY1yzCRraUNNyfAVdOkLCqY7xMZz11lVcvHqC1Y= +github.com/lestrrat-go/jwx v1.2.23/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8= github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4= github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -627,6 +627,8 @@ github.com/otiai10/mint v1.3.3 h1:7JgpsBaN0uMkyju4tbYHu0mnM55hNKVYLsXmwr15NQI= github.com/otiai10/mint v1.3.3/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pelletier/go-toml/v2 v2.0.0-beta.8 h1:dy81yyLYJDwMTifq24Oi/IslOslRrDSb3jwDggjz3Z0= +github.com/pelletier/go-toml/v2 v2.0.0-beta.8/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= @@ -708,8 +710,8 @@ github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmq github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk= -github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU= +github.com/spf13/viper v1.11.0 h1:7OX/1FS6n7jHD1zGrZTM7WtY13ZELRyosK4k93oPr44= +github.com/spf13/viper v1.11.0/go.mod h1:djo0X/bA5+tYVoCn+C7cAYJGcVn/qYLFTG8gdUsX7Zk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -934,8 +936,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412071739-889880a91fd5 h1:NubxfvTRuNb4RVzWrIDAUzUvREH1HkCD4JjyQTSG9As= -golang.org/x/sys v0.0.0-20220412071739-889880a91fd5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1166,8 +1168,9 @@ google.golang.org/genproto v0.0.0-20220310185008-1973136f34c6/go.mod h1:kGP+zUP2 google.golang.org/genproto v0.0.0-20220324131243-acbaeb5b85eb/go.mod h1:hAL49I2IFola2sVEjAn7MEwsja0xp51I0tlGAf9hz4E= google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= google.golang.org/genproto v0.0.0-20220405205423-9d709892a2bf/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= -google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac h1:qSNTkEN+L2mvWcLgJOR+8bdHX9rN/IdU3A1Ghpfb1Rg= google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= +google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 h1:XGQ6tc+EnM35IAazg4y6AHmUg4oK8NXsXaILte1vRlk= +google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= diff --git a/httpd/api_http_user.go b/httpd/api_http_user.go index 8be88184..2d282e84 100644 --- a/httpd/api_http_user.go +++ b/httpd/api_http_user.go @@ -34,7 +34,7 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) return nil, err } @@ -43,6 +43,10 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err r.RemoteAddr, user), request: r, } + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return connection, err + } return connection, nil } @@ -52,7 +56,6 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -70,7 +73,6 @@ func createUserDir(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -90,22 +92,7 @@ func createUserDir(w http.ResponseWriter, r *http.Request) { func renameUserDir(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - connection, err := getUserConnection(w, r) - if err != nil { - return - } - common.Connections.Add(connection) - defer common.Connections.Remove(connection.GetID()) - - oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) - newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) - err = connection.Rename(oldName, newName) - if err != nil { - sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename directory %#v to %#v", oldName, newName), - getMappedStatusCode(err)) - return - } - sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %#v renamed to %#v", oldName, newName), http.StatusOK) + renameItem(w, r) } func deleteUserDir(w http.ResponseWriter, r *http.Request) { @@ -114,7 +101,6 @@ func deleteUserDir(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -132,7 +118,6 @@ func getUserFile(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -183,7 +168,6 @@ func setFileDirMetadata(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -214,7 +198,6 @@ func uploadUserFile(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -258,6 +241,8 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) { if err != nil { return } + defer common.Connections.Remove(connection.GetID()) + transferQuota := connection.GetTransferQuota() if !transferQuota.HasUploadSpace() { connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") @@ -265,8 +250,6 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) { http.StatusRequestEntityTooLarge) return } - common.Connections.Add(connection) - defer common.Connections.Remove(connection.GetID()) t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) r.Body = t @@ -332,22 +315,7 @@ func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connectio func renameUserFile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - connection, err := getUserConnection(w, r) - if err != nil { - return - } - common.Connections.Add(connection) - defer common.Connections.Remove(connection.GetID()) - - oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) - newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) - err = connection.Rename(oldName, newName) - if err != nil { - sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename file %#v to %#v", oldName, newName), - getMappedStatusCode(err)) - return - } - sendAPIResponse(w, r, nil, fmt.Sprintf("File %#v renamed to %#v", oldName, newName), http.StatusOK) + renameItem(w, r) } func deleteUserFile(w http.ResponseWriter, r *http.Request) { @@ -356,7 +324,6 @@ func deleteUserFile(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -393,7 +360,6 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) { if err != nil { return } - common.Connections.Add(connection) defer common.Connections.Remove(connection.GetID()) var filesList []string @@ -581,3 +547,21 @@ func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath stri } } } + +func renameItem(w http.ResponseWriter, r *http.Request) { + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) + err = connection.Rename(oldName, newName) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %#v -> %#v", oldName, newName), + getMappedStatusCode(err)) + return + } + sendAPIResponse(w, r, nil, fmt.Sprintf("%#v renamed to %#v", oldName, newName), http.StatusOK) +} diff --git a/httpd/api_shares.go b/httpd/api_shares.go index 7b8b26bb..4d034c88 100644 --- a/httpd/api_shares.go +++ b/httpd/api_shares.go @@ -167,7 +167,10 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http. return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) contents, err := connection.ReadDir(name) @@ -194,7 +197,10 @@ func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) info, err := connection.Stat(name, 1) @@ -231,7 +237,10 @@ func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request) return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) compress := true @@ -289,7 +298,10 @@ func (s *httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request) } dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) if err := doUploadFile(w, r, connection, filePath); err != nil { dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck @@ -313,7 +325,10 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request) return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) diff --git a/httpd/api_utils.go b/httpd/api_utils.go index 4dabf0a1..15763a11 100644 --- a/httpd/api_utils.go +++ b/httpd/api_utils.go @@ -498,7 +498,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err) } -func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error { +func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error { if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) { logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username) return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username) @@ -507,7 +507,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username) return fmt.Errorf("login method password is not allowed for user %#v", user.Username) } - if user.MaxSessions > 0 { + if checkSessions && user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) if activeSessions >= user.MaxSessions { logger.Info(logSender, connectionID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username, diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index c183b4d0..b4730af6 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -3857,7 +3857,8 @@ func TestCloseActiveConnection(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - common.Connections.Add(fakeConn) + err = common.Connections.Add(fakeConn) + assert.NoError(t, err) _, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK) assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(), 0) @@ -3870,12 +3871,14 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - common.Connections.Add(fakeConn) + err = common.Connections.Add(fakeConn) + assert.NoError(t, err) c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user) fakeConn1 := &fakeConnection{ BaseConnection: c1, } - common.Connections.Add(fakeConn1) + err = common.Connections.Add(fakeConn1) + assert.NoError(t, err) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0") assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(), 2) @@ -3883,8 +3886,10 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) { assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(), 0) - common.Connections.Add(fakeConn) - common.Connections.Add(fakeConn1) + err = common.Connections.Add(fakeConn) + assert.NoError(t, err) + err = common.Connections.Add(fakeConn1) + assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(), 2) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) @@ -5173,7 +5178,8 @@ func TestLoaddataMode(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - common.Connections.Add(fakeConn) + err = common.Connections.Add(fakeConn) + assert.NoError(t, err) assert.Len(t, common.Connections.GetStats(), 1) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) @@ -8714,7 +8720,8 @@ func TestWebClientMaxConnections(t *testing.T) { connection := &httpd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user), } - common.Connections.Add(connection) + err = common.Connections.Add(connection) + assert.NoError(t, err) _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) @@ -8895,20 +8902,57 @@ func TestMaxSessions(t *testing.T) { u.Email = "user@session.com" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) - _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) // now add a fake connection fs := vfs.NewOsFs("id", os.TempDir(), "") connection := &httpd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user), } - common.Connections.Add(connection) + err = common.Connections.Add(connection) + assert.NoError(t, err) _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) + // try an user API call + req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + // web client requests + req, err = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + // test reset password smtpCfg := smtp.Config{ Host: "127.0.0.1", @@ -8924,11 +8968,11 @@ func TestMaxSessions(t *testing.T) { form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) lastResetCode = "" - req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rr := executeRequest(req) + rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) form = make(url.Values) @@ -9630,6 +9674,123 @@ func TestShareUsage(t *testing.T) { executeRequest(req) } +func TestShareMaxSessions(t *testing.T) { + u := getTestUser() + u.MaxSessions = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share max sessions read", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), "") + connection := &httpd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user), + } + err = common.Connections.Add(connection) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // now test a write share + share = dataprovider.Share{ + Name: "test share max sessions write", + Scope: dataprovider.ShareScopeWrite, + Paths: []string{"/"}, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content"))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file1.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file1 content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + common.Connections.Remove(connection.GetID()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(), 0) +} + func TestShareUploadSingle(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) diff --git a/httpd/middleware.go b/httpd/middleware.go index 01106f99..baefa48c 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -434,7 +434,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu return err } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) return err } diff --git a/httpd/oidc.go b/httpd/oidc.go index 11950855..4b6dbf20 100644 --- a/httpd/oidc.go +++ b/httpd/oidc.go @@ -331,7 +331,7 @@ func (t *oidcToken) getUser(r *http.Request) error { return err } connectionID := fmt.Sprintf("%v_%v", common.ProtocolOIDC, xid.New().String()) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err) return err } diff --git a/httpd/server.go b/httpd/server.go index b7946efc..61f95d0b 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -228,7 +228,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re return } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) s.renderClientLoginPage(w, err.Error(), ipAddr) return @@ -268,7 +268,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r return } connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) - if err := checkHTTPClientUser(user, r, connectionID); err != nil { + if err := checkHTTPClientUser(user, r, connectionID, true); err != nil { s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr) return } @@ -760,7 +760,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) { return } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) return @@ -920,7 +920,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, logger.Debug(logSender, "", "signature mismatch for user %#v, unable to refresh cookie", user.Username) return } - if err := checkHTTPClientUser(&user, r, xid.New().String()); err != nil { + if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil { logger.Debug(logSender, "", "unable to refresh cookie for user %#v: %v", user.Username, err) return } diff --git a/httpd/webclient.go b/httpd/webclient.go index ab2130b6..38598418 100644 --- a/httpd/webclient.go +++ b/httpd/webclient.go @@ -595,7 +595,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http. connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil { s.renderClientForbiddenPage(w, r, err.Error()) return } @@ -604,7 +604,10 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http. r.RemoteAddr, user), request: r, } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -635,7 +638,10 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R s.renderClientMessagePage(w, r, "Invalid share path", "", getRespStatus(err), err, "") return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) contents, err := connection.ReadDir(name) @@ -691,7 +697,10 @@ func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request return } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) var info os.FileInfo @@ -735,7 +744,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -744,7 +753,10 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. r.RemoteAddr, user), request: r, } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -809,7 +821,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil { s.renderClientForbiddenPage(w, r, err.Error()) return } @@ -818,7 +830,10 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques r.RemoteAddr, user), request: r, } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) @@ -866,7 +881,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques connID := xid.New().String() protocol := getProtocolFromRequest(r) connectionID := fmt.Sprintf("%v_%v", protocol, connID) - if err := checkHTTPClientUser(&user, r, connectionID); err != nil { + if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil { s.renderClientForbiddenPage(w, r, err.Error()) return } @@ -875,7 +890,10 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques r.RemoteAddr, user), request: r, } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "") + return + } defer common.Connections.Remove(connection.GetID()) name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) diff --git a/sftpd/cryptfs_test.go b/sftpd/cryptfs_test.go index 22de35f4..8c14d1c1 100644 --- a/sftpd/cryptfs_test.go +++ b/sftpd/cryptfs_test.go @@ -369,7 +369,7 @@ func TestTruncate(t *testing.T) { } func TestSCPBasicHandlingCryptoFs(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -427,7 +427,7 @@ func TestSCPBasicHandlingCryptoFs(t *testing.T) { } func TestSCPRecursiveCryptFs(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true diff --git a/sftpd/handler.go b/sftpd/handler.go index 39552a67..f8c2a76c 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -457,8 +457,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO return t, nil } -// Disconnect disconnects the client closing the network connection +// Disconnect disconnects the client by closing the channel func (c *Connection) Disconnect() error { + if c.channel == nil { + c.Log(logger.LevelWarn, "cannot disconnect a nil channel") + return nil + } return c.channel.Close() } diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index e210472c..d6c74e92 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -14,6 +14,7 @@ import ( "github.com/eikenb/pipeat" "github.com/pkg/sftp" + "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -2152,3 +2153,45 @@ func TestLoadRevokedUserCertsFile(t *testing.T) { err = os.RemoveAll(r.filePath) assert.NoError(t, err) } + +func TestMaxUserSessions(t *testing.T) { + connection := &Connection{ + BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user_max_sessions", + HomeDir: filepath.Clean(os.TempDir()), + MaxSessions: 1, + }, + }), + } + err := common.Connections.Add(connection) + assert.NoError(t, err) + + c := Configuration{} + c.handleSftpConnection(nil, connection) + + sshCmd := sshCommand{ + command: "cd", + connection: connection, + } + err = sshCmd.handle() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too many open sessions") + } + scpCmd := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + }, + } + err = scpCmd.handle() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too many open sessions") + } + err = ServeSubSystemConnection(&connection.User, connection.ID, nil, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too many open sessions") + } + common.Connections.Remove(connection.GetID()) + assert.Len(t, common.Connections.GetStats(), 0) +} diff --git a/sftpd/scp.go b/sftpd/scp.go index 3a501de2..211b3b4c 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -37,7 +37,10 @@ func (c *scpCommand) handle() (err error) { err = common.ErrGenericFailure } }() - common.Connections.Add(c.connection) + if err := common.Connections.Add(c.connection); err != nil { + logger.Info(logSender, "", "unable to add SCP connection: %v", err) + return err + } defer common.Connections.Remove(c.connection.GetID()) destPath := c.getDestPath() diff --git a/sftpd/server.go b/sftpd/server.go index 8b2e2967..6931d545 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -562,7 +562,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve case "subsystem": if string(req.Payload[4:]) == "sftp" { ok = true - connection := Connection{ + connection := &Connection{ BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(), conn.RemoteAddr().String(), user), ClientVersion: string(sconn.ClientVersion()), @@ -571,7 +571,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve channel: channel, folderPrefix: c.FolderPrefix, } - go c.handleSftpConnection(channel, &connection) + go c.handleSftpConnection(channel, connection) } case "exec": // protocol will be set later inside processSSHCommand it could be SSH or SCP @@ -600,7 +600,11 @@ func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Co logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack())) } }() - common.Connections.Add(connection) + if err := common.Connections.Add(connection); err != nil { + errClose := connection.Disconnect() + logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose) + return + } defer common.Connections.Remove(connection.GetID()) // Create the server instance for the channel using the handler we created above. diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index a123e1f4..991524c6 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -3,6 +3,7 @@ package sftpd_test import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/sha256" "crypto/sha512" @@ -129,6 +130,7 @@ var ( allPerms = []string{dataprovider.PermAny} homeBasePath string scpPath string + scpForce bool gitPath string sshPath string hookCmdPath string @@ -935,8 +937,6 @@ func TestConcurrency(t *testing.T) { conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { - err = checkBasicSFTP(client) - assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client) assert.NoError(t, err) assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0) @@ -9231,7 +9231,7 @@ func TestGitErrors(t *testing.T) { // Start SCP tests func TestSCPBasicHandling(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9295,7 +9295,7 @@ func TestSCPBasicHandling(t *testing.T) { } func TestSCPUploadFileOverwrite(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9376,7 +9376,7 @@ func TestSCPUploadFileOverwrite(t *testing.T) { } func TestSCPRecursive(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9485,7 +9485,7 @@ func TestSCPStartDirectory(t *testing.T) { } func TestSCPPatternsFilter(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9606,7 +9606,7 @@ func TestSCPUploadMaxSize(t *testing.T) { } func TestSCPVirtualFolders(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9658,7 +9658,7 @@ func TestSCPVirtualFolders(t *testing.T) { } func TestSCPNestedFolders(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated) @@ -9791,7 +9791,7 @@ func TestSCPNestedFolders(t *testing.T) { } func TestSCPVirtualFoldersQuota(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9889,7 +9889,7 @@ func TestSCPVirtualFoldersQuota(t *testing.T) { } func TestSCPPermsSubDirs(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9929,7 +9929,7 @@ func TestSCPPermsSubDirs(t *testing.T) { } func TestSCPPermCreateDirs(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9963,7 +9963,7 @@ func TestSCPPermCreateDirs(t *testing.T) { } func TestSCPPermUpload(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -9987,7 +9987,7 @@ func TestSCPPermUpload(t *testing.T) { } func TestSCPPermOverwrite(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10014,7 +10014,7 @@ func TestSCPPermOverwrite(t *testing.T) { } func TestSCPPermDownload(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10043,7 +10043,7 @@ func TestSCPPermDownload(t *testing.T) { } func TestSCPQuotaSize(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10099,7 +10099,7 @@ func TestSCPQuotaSize(t *testing.T) { } func TestSCPEscapeHomeDir(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10135,7 +10135,7 @@ func TestSCPEscapeHomeDir(t *testing.T) { } func TestSCPUploadPaths(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10170,7 +10170,7 @@ func TestSCPUploadPaths(t *testing.T) { } func TestSCPOverwriteDirWithFile(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } usePubKey := true @@ -10194,7 +10194,7 @@ func TestSCPOverwriteDirWithFile(t *testing.T) { } func TestSCPRemoteToRemote(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } if runtime.GOOS == osWindows { @@ -10229,7 +10229,7 @@ func TestSCPRemoteToRemote(t *testing.T) { } func TestSCPErrors(t *testing.T) { - if len(scpPath) == 0 { + if scpPath == "" { t.Skip("scp command not found, unable to execute this test") } u := getTestUser(true) @@ -10682,6 +10682,9 @@ func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive if recursive { args = append(args, "-r") } + if scpForce { + args = append(args, "-O") + } args = append(args, "-P") args = append(args, "2022") args = append(args, "-o") @@ -10707,6 +10710,9 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem args = append(args, "-r") } } + if scpForce { + args = append(args, "-O") + } args = append(args, "-P") args = append(args, "2022") args = append(args, "-o") @@ -10770,6 +10776,12 @@ func checkSystemCommands() { logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err) logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err) scpPath = "" + } else { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, scpPath, "-O") + out, _ := cmd.CombinedOutput() + scpForce = !strings.Contains(string(out), "option -- O") } } diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 199d9283..40254427 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -115,7 +115,10 @@ func (c *sshCommand) handle() (err error) { err = common.ErrGenericFailure } }() - common.Connections.Add(c.connection) + if err := common.Connections.Add(c.connection); err != nil { + logger.Info(logSender, "", "unable to add SSH command connection: %v", err) + return err + } defer common.Connections.Remove(c.connection.GetID()) c.connection.UpdateLastActivity() @@ -131,7 +134,7 @@ func (c *sshCommand) handle() (err error) { c.sendExitStatus(nil) } else if c.command == "pwd" { // hard coded response to "/" - c.connection.channel.Write([]byte("/\n")) //nolint:errcheck + c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck c.sendExitStatus(nil) } else if c.command == "sftpgo-copy" { return c.handleSFTPGoCopy() diff --git a/sftpd/subsystem.go b/sftpd/subsystem.go index bcd0d1d2..879bb14a 100644 --- a/sftpd/subsystem.go +++ b/sftpd/subsystem.go @@ -43,7 +43,6 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) return err } - dataprovider.UpdateLastLogin(user) connection := &Connection{ BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, "", "", *user), @@ -52,9 +51,15 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read LocalAddr: &net.IPAddr{}, channel: newSubsystemChannel(reader, writer), } - common.Connections.Add(connection) + err = common.Connections.Add(connection) + if err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to add connection: %v close fs error: %v", err, errClose) + return err + } defer common.Connections.Remove(connection.GetID()) + dataprovider.UpdateLastLogin(user) server := sftp.NewRequestServer(connection.channel, sftp.Handlers{ FileGet: connection, FilePut: connection, diff --git a/vfs/cryptfs.go b/vfs/cryptfs.go index 9a0515dc..5c69c3df 100644 --- a/vfs/cryptfs.go +++ b/vfs/cryptfs.go @@ -203,18 +203,14 @@ func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) { if err != nil { return nil, err } - entries, err := f.ReadDir(-1) + list, err := f.Readdir(-1) f.Close() if err != nil { return nil, err } - result := make([]os.FileInfo, len(entries)) - for idx, entry := range entries { - info, err := entry.Info() - if err != nil { - return nil, err - } - result[idx] = fs.ConvertFileInfo(info) + result := make([]os.FileInfo, 0, len(list)) + for _, info := range list { + result = append(result, fs.ConvertFileInfo(info)) } return result, nil } diff --git a/vfs/osfs.go b/vfs/osfs.go index c75c2733..bc7d2323 100644 --- a/vfs/osfs.go +++ b/vfs/osfs.go @@ -3,7 +3,6 @@ package vfs import ( "fmt" "io" - "io/fs" "net/http" "os" "path" @@ -171,20 +170,12 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) { if err != nil { return nil, err } - entries, err := f.ReadDir(-1) + list, err := f.Readdir(-1) f.Close() if err != nil { return nil, err } - result := make([]fs.FileInfo, len(entries)) - for idx, entry := range entries { - info, err := entry.Info() - if err != nil { - return nil, err - } - result[idx] = info - } - return result, nil + return list, nil } // IsUploadResumeSupported returns true if resuming uploads is supported diff --git a/webdavd/server.go b/webdavd/server.go index d9a981b1..55ee4c37 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -196,19 +196,25 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - updateLoginMetrics(&user, ipAddr, loginMethod, err) - - ctx := context.WithValue(r.Context(), requestIDKey, connectionID) - ctx = context.WithValue(ctx, requestStartKey, time.Now()) - connection := &Connection{ BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r), r.RemoteAddr, user), request: r, } - common.Connections.Add(connection) + if err = common.Connections.Add(connection); err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose) + updateLoginMetrics(&user, ipAddr, loginMethod, err) + http.Error(w, err.Error(), http.StatusTooManyRequests) + return + } defer common.Connections.Remove(connection.GetID()) + updateLoginMetrics(&user, ipAddr, loginMethod, err) + + ctx := context.WithValue(r.Context(), requestIDKey, connectionID) + ctx = context.WithValue(ctx, requestStartKey, time.Now()) + dataprovider.UpdateLastLogin(&user) if s.checkRequestMethod(ctx, r, connection) { @@ -311,14 +317,6 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo user.Username, loginMethod) return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username) } - if user.MaxSessions > 0 { - activeSessions := common.Connections.GetActiveSessions(user.Username) - if activeSessions >= user.MaxSessions { - logger.Info(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v", - user.Username, activeSessions, user.MaxSessions) - return connID, fmt.Errorf("too many open sessions: %v", activeSessions) - } - } if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { logger.Info(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr) diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index 8eee2ae7..a6f03eb4 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -1167,7 +1167,8 @@ func TestMaxConnections(t *testing.T) { connection := &webdavd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } - common.Connections.Add(connection) + err = common.Connections.Add(connection) + assert.NoError(t, err) assert.Error(t, checkBasicFunc(client)) common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -1222,7 +1223,8 @@ func TestMaxSessions(t *testing.T) { connection := &webdavd.Connection{ BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), } - common.Connections.Add(connection) + err = common.Connections.Add(connection) + assert.NoError(t, err) assert.Error(t, checkBasicFunc(client)) common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK)