Explorar o código

refactoring of user session counters

Fixes #792

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino %!s(int64=3) %!d(string=hai) anos
pai
achega
002a06629e

+ 47 - 8
common/common.go

@@ -98,6 +98,7 @@ func init() {
 	Connections.clients = clientsMap{
 		clients: make(map[string]int),
 	}
+	Connections.perUserConns = make(map[string]int)
 }
 
 // errors definitions
@@ -345,6 +346,7 @@ type ActiveTransfer interface {
 type ActiveConnection interface {
 	GetID() string
 	GetUsername() string
+	GetMaxSessions() int
 	GetLocalAddress() string
 	GetRemoteAddress() string
 	GetClientVersion() string
@@ -733,6 +735,29 @@ type ActiveConnections struct {
 	sync.RWMutex
 	connections    []ActiveConnection
 	sshConnections []*SSHConnection
+	perUserConns   map[string]int
+}
+
+// internal method, must be called within a locked block
+func (conns *ActiveConnections) addUserConnection(username string) {
+	if username == "" {
+		return
+	}
+	conns.perUserConns[username]++
+}
+
+// internal method, must be called within a locked block
+func (conns *ActiveConnections) removeUserConnection(username string) {
+	if username == "" {
+		return
+	}
+	if val, ok := conns.perUserConns[username]; ok {
+		conns.perUserConns[username]--
+		if val > 1 {
+			return
+		}
+		delete(conns.perUserConns, username)
+	}
 }
 
 // GetActiveSessions returns the number of active sessions for the given username.
@@ -741,24 +766,27 @@ func (conns *ActiveConnections) GetActiveSessions(username string) int {
 	conns.RLock()
 	defer conns.RUnlock()
 
-	numSessions := 0
-	for _, c := range conns.connections {
-		if c.GetUsername() == username {
-			numSessions++
-		}
-	}
-	return numSessions
+	return conns.perUserConns[username]
 }
 
 // Add adds a new connection to the active ones
-func (conns *ActiveConnections) Add(c ActiveConnection) {
+func (conns *ActiveConnections) Add(c ActiveConnection) error {
 	conns.Lock()
 	defer conns.Unlock()
 
+	if username := c.GetUsername(); username != "" {
+		if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
+			if val := conns.perUserConns[username]; val >= maxSessions {
+				return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
+			}
+		}
+		conns.addUserConnection(username)
+	}
 	conns.connections = append(conns.connections, c)
 	metric.UpdateActiveConnectionsSize(len(conns.connections))
 	logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
 		c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
+	return nil
 }
 
 // Swap replaces an existing connection with the given one.
@@ -771,6 +799,16 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
 
 	for idx, conn := range conns.connections {
 		if conn.GetID() == c.GetID() {
+			conns.removeUserConnection(conn.GetUsername())
+			if username := c.GetUsername(); username != "" {
+				if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
+					if val := conns.perUserConns[username]; val >= maxSessions {
+						conns.addUserConnection(conn.GetUsername())
+						return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
+					}
+				}
+				conns.addUserConnection(username)
+			}
 			err := conn.CloseFS()
 			conns.connections[idx] = c
 			logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
@@ -793,6 +831,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
 			conns.connections[idx] = conns.connections[lastIdx]
 			conns.connections[lastIdx] = nil
 			conns.connections = conns.connections[:lastIdx]
+			conns.removeUserConnection(conn.GetUsername())
 			metric.UpdateActiveConnectionsSize(lastIdx)
 			logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
 				conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)

+ 58 - 12
common/common_test.go

@@ -370,6 +370,29 @@ func TestWhitelist(t *testing.T) {
 	Config = configCopy
 }
 
+func TestUserMaxSessions(t *testing.T) {
+	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			Username:    userTestUsername,
+			MaxSessions: 1,
+		},
+	})
+	fakeConn := &fakeConnection{
+		BaseConnection: c,
+	}
+	err := Connections.Add(fakeConn)
+	assert.NoError(t, err)
+	err = Connections.Add(fakeConn)
+	assert.Error(t, err)
+	err = Connections.Swap(fakeConn)
+	assert.NoError(t, err)
+	Connections.Remove(fakeConn.GetID())
+	Connections.Lock()
+	Connections.removeUserConnection(userTestUsername)
+	Connections.Unlock()
+	assert.Len(t, Connections.GetStats(), 0)
+}
+
 func TestMaxConnections(t *testing.T) {
 	oldValue := Config.MaxTotalConnections
 	perHost := Config.MaxPerHostConnections
@@ -387,7 +410,8 @@ func TestMaxConnections(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	Connections.Add(fakeConn)
+	err := Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Len(t, Connections.GetStats(), 1)
 	assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
 
@@ -466,14 +490,16 @@ func TestIdleConnections(t *testing.T) {
 	sshConn1.lastActivity = c.lastActivity
 	sshConn2.lastActivity = c.lastActivity
 	Connections.AddSSHConnection(sshConn1)
-	Connections.Add(fakeConn)
+	err = Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Equal(t, Connections.GetActiveSessions(username), 1)
 	c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user)
 	fakeConn = &fakeConnection{
 		BaseConnection: c,
 	}
 	Connections.AddSSHConnection(sshConn2)
-	Connections.Add(fakeConn)
+	err = Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Equal(t, Connections.GetActiveSessions(username), 2)
 
 	cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
@@ -481,7 +507,8 @@ func TestIdleConnections(t *testing.T) {
 	fakeConn = &fakeConnection{
 		BaseConnection: cFTP,
 	}
-	Connections.Add(fakeConn)
+	err = Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Equal(t, Connections.GetActiveSessions(username), 2)
 	assert.Len(t, Connections.GetStats(), 3)
 	Connections.RLock()
@@ -521,7 +548,8 @@ func TestCloseConnection(t *testing.T) {
 		BaseConnection: c,
 	}
 	assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
-	Connections.Add(fakeConn)
+	err := Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Len(t, Connections.GetStats(), 1)
 	res := Connections.Close(fakeConn.GetID())
 	assert.True(t, res)
@@ -536,19 +564,34 @@ func TestSwapConnection(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	Connections.Add(fakeConn)
+	err := Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	if assert.Len(t, Connections.GetStats(), 1) {
 		assert.Equal(t, "", Connections.GetStats()[0].Username)
 	}
 	c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
 		BaseUser: sdk.BaseUser{
-			Username: userTestUsername,
+			Username:    userTestUsername,
+			MaxSessions: 1,
 		},
 	})
 	fakeConn = &fakeConnection{
 		BaseConnection: c,
 	}
-	err := Connections.Swap(fakeConn)
+	c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			Username: userTestUsername,
+		},
+	})
+	fakeConn1 := &fakeConnection{
+		BaseConnection: c1,
+	}
+	err = Connections.Add(fakeConn1)
+	assert.NoError(t, err)
+	err = Connections.Swap(fakeConn)
+	assert.Error(t, err)
+	Connections.Remove(fakeConn1.ID)
+	err = Connections.Swap(fakeConn)
 	assert.NoError(t, err)
 	if assert.Len(t, Connections.GetStats(), 1) {
 		assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
@@ -600,9 +643,12 @@ func TestConnectionStatus(t *testing.T) {
 		command:        "PROPFIND",
 	}
 	t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
-	Connections.Add(fakeConn1)
-	Connections.Add(fakeConn2)
-	Connections.Add(fakeConn3)
+	err := Connections.Add(fakeConn1)
+	assert.NoError(t, err)
+	err = Connections.Add(fakeConn2)
+	assert.NoError(t, err)
+	err = Connections.Add(fakeConn3)
+	assert.NoError(t, err)
 
 	stats := Connections.GetStats()
 	assert.Len(t, stats, 3)
@@ -628,7 +674,7 @@ func TestConnectionStatus(t *testing.T) {
 		}
 	}
 
-	err := t1.Close()
+	err = t1.Close()
 	assert.NoError(t, err)
 	err = t2.Close()
 	assert.NoError(t, err)

+ 5 - 0
common/connection.go

@@ -80,6 +80,11 @@ func (c *BaseConnection) GetUsername() string {
 	return c.User.Username
 }
 
+// GetMaxSessions returns the maximum number of concurrent sessions allowed
+func (c *BaseConnection) GetMaxSessions() int {
+	return c.User.MaxSessions
+}
+
 // GetProtocol returns the protocol for the connection
 func (c *BaseConnection) GetProtocol() string {
 	return c.protocol

+ 18 - 9
common/transferschecker_test.go

@@ -62,7 +62,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
 	transfer1.BytesReceived = 150
-	Connections.Add(fakeConn1)
+	err = Connections.Add(fakeConn1)
+	assert.NoError(t, err)
 	// the transferschecker will do nothing if there is only one ongoing transfer
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
@@ -76,7 +77,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 		"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
 	transfer1.BytesReceived = 50
 	transfer2.BytesReceived = 60
-	Connections.Add(fakeConn2)
+	err = Connections.Add(fakeConn2)
+	assert.NoError(t, err)
 
 	connID3 := xid.New().String()
 	conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
@@ -86,7 +88,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
 		"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
 	transfer3.BytesReceived = 60 // this value will be ignored, this is a download
-	Connections.Add(fakeConn3)
+	err = Connections.Add(fakeConn3)
+	assert.NoError(t, err)
 
 	// the transfers are not overquota
 	Connections.checkTransfers()
@@ -146,7 +149,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
 		filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
 		100, 0, true, fsFolder, dataprovider.TransferQuota{})
-	Connections.Add(fakeConn4)
+	err = Connections.Add(fakeConn4)
+	assert.NoError(t, err)
 	connID5 := xid.New().String()
 	conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
 	fakeConn5 := &fakeConnection{
@@ -156,7 +160,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
 		filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
 		100, 0, true, fsFolder, dataprovider.TransferQuota{})
 
-	Connections.Add(fakeConn5)
+	err = Connections.Add(fakeConn5)
+	assert.NoError(t, err)
 	transfer4.BytesReceived = 50
 	transfer5.BytesReceived = 40
 	Connections.checkTransfers()
@@ -245,7 +250,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	transfer1.BytesReceived = 150
-	Connections.Add(fakeConn1)
+	err = Connections.Add(fakeConn1)
+	assert.NoError(t, err)
 	// the transferschecker will do nothing if there is only one ongoing transfer
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
@@ -258,7 +264,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
 		"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
 	transfer2.BytesReceived = 150
-	Connections.Add(fakeConn2)
+	err = Connections.Add(fakeConn2)
+	assert.NoError(t, err)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer1.errAbort)
 	assert.Nil(t, transfer2.errAbort)
@@ -294,7 +301,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
 		"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
 	transfer3.BytesSent = 150
-	Connections.Add(fakeConn3)
+	err = Connections.Add(fakeConn3)
+	assert.NoError(t, err)
 
 	connID4 := xid.New().String()
 	conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
@@ -304,7 +312,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
 	transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
 		"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
 	transfer4.BytesSent = 150
-	Connections.Add(fakeConn4)
+	err = Connections.Add(fakeConn4)
+	assert.NoError(t, err)
 	Connections.checkTransfers()
 	assert.Nil(t, transfer3.errAbort)
 	assert.Nil(t, transfer4.errAbort)

+ 2 - 1
ftpd/internal_test.go

@@ -593,7 +593,8 @@ func TestClientVersion(t *testing.T) {
 		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
 		clientContext:  mockCC,
 	}
-	common.Connections.Add(connection)
+	err := common.Connections.Add(connection)
+	assert.NoError(t, err)
 	stats := common.Connections.GetStats()
 	if assert.Len(t, stats, 1) {
 		assert.Equal(t, "mock version", stats[0].ClientVersion)

+ 5 - 5
ftpd/server.go

@@ -168,8 +168,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
 			cc.RemoteAddr().String(), user),
 		clientContext: cc,
 	}
-	common.Connections.Add(connection)
-	return s.initialMsg, nil
+	err = common.Connections.Add(connection)
+	return s.initialMsg, err
 }
 
 // ClientDisconnected is called when the user disconnects, even if he never authenticated
@@ -367,9 +367,9 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
 	}
 	err = common.Connections.Swap(connection)
 	if err != nil {
-		err = user.CloseFs()
-		logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err)
-		return nil, common.ErrInternalFailure
+		errClose := user.CloseFs()
+		logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose)
+		return nil, err
 	}
 	return connection, nil
 }

+ 8 - 7
go.mod

@@ -12,9 +12,9 @@ require (
 	github.com/aws/aws-sdk-go-v2/config v1.15.3
 	github.com/aws/aws-sdk-go-v2/credentials v1.11.2
 	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3
-	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4
+	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5
 	github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3
-	github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4
+	github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5
 	github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4
 	github.com/aws/aws-sdk-go-v2/service/sts v1.16.3
 	github.com/cockroachdb/cockroach-go/v2 v2.2.8
@@ -32,10 +32,10 @@ require (
 	github.com/grandcat/zeroconf v1.0.0
 	github.com/hashicorp/go-hclog v1.2.0
 	github.com/hashicorp/go-plugin v1.4.3
-	github.com/hashicorp/go-retryablehttp v0.7.0
+	github.com/hashicorp/go-retryablehttp v0.7.1
 	github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
 	github.com/klauspost/compress v1.15.1
-	github.com/lestrrat-go/jwx v1.2.22
+	github.com/lestrrat-go/jwx v1.2.23
 	github.com/lib/pq v1.10.5
 	github.com/lithammer/shortuuid/v3 v3.0.7
 	github.com/mattn/go-sqlite3 v1.14.12
@@ -54,7 +54,7 @@ require (
 	github.com/shirou/gopsutil/v3 v3.22.3
 	github.com/spf13/afero v1.8.2
 	github.com/spf13/cobra v1.4.0
-	github.com/spf13/viper v1.10.1
+	github.com/spf13/viper v1.11.0
 	github.com/stretchr/testify v1.7.1
 	github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62
 	github.com/unrolled/secure v1.10.0
@@ -67,7 +67,7 @@ require (
 	golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
 	golang.org/x/net v0.0.0-20220412020605-290c469a71a5
 	golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
-	golang.org/x/sys v0.0.0-20220412071739-889880a91fd5
+	golang.org/x/sys v0.0.0-20220412211240-33da011f77ad
 	golang.org/x/time v0.0.0-20220411224347-583f2d630306
 	google.golang.org/api v0.74.0
 	gopkg.in/natefinch/lumberjack.v2 v2.0.0
@@ -130,6 +130,7 @@ require (
 	github.com/mitchellh/mapstructure v1.4.3 // indirect
 	github.com/oklog/run v1.1.0 // indirect
 	github.com/pelletier/go-toml v1.9.4 // indirect
+	github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/power-devops/perfstat v0.0.0-20220216144756-c35f1ee13d7c // indirect
@@ -151,7 +152,7 @@ require (
 	golang.org/x/tools v0.1.10 // indirect
 	golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac // indirect
+	google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 // indirect
 	google.golang.org/grpc v1.45.0 // indirect
 	google.golang.org/protobuf v1.28.0 // indirect
 	gopkg.in/ini.v1 v1.66.4 // indirect

+ 16 - 13
go.sum

@@ -142,8 +142,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.11.2/go.mod h1:j8YsY9TXTm31k4eFhspiQ
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 h1:LWPg5zjHV9oz/myQr4wMs0gi4CjnDN/ILmyZUFYXZsU=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk=
 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4 h1:iqcMQBj/B3FPxVb5SGNHC8XAh64hmaWUC8piZArBE7U=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4/go.mod h1:s79ZPBpDzcR1BCuAhGCF1rgmd/QmLueKCvdkmX4SDgg=
+github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5 h1:lPo/NX1o4vkk2C7mHmB2FCf9Qp7KZNHrlzHxdP/yugw=
+github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5/go.mod h1:JNo9mEKrjnmDBc19z65TZmj1xG9PQHu2GOlApYk31DU=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 h1:onz/VaaxZ7Z4V+WIN9Txly9XLTmoOh1oJ8XcAC3pako=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 h1:9stUQR/u2KXU6HkFJYlqnZEjBnbgrVbG6I5HN09xZh0=
@@ -164,8 +164,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYY
 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3 h1:xqXHk4UDW7ii4MRciyLpY87yuZds0iymmgHt3h35xTE=
 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3/go.mod h1:HT0cm2+NUCF33MdXjck554HC6VRgQ4q6JIlSqlYZ18Y=
 github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak=
-github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4 h1:frOI/v6KWuKGlKUA5gheRw01EDpxcCxTalFQkCOZXAo=
-github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5 h1:A3PuAUlh1u47WHcM68CDaG9ZWjK7ewePjDp+0dY9yv4=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4 h1:EmIEXOjAdXtxa2OGM1VAajZV/i06Q8qd4kBpJd9/p1k=
 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o=
 github.com/aws/aws-sdk-go-v2/service/sns v1.17.4/go.mod h1:kElt+uCcXxcqFyc+bQqZPFD9DME/eC6oHBXvFzQ9Bcw=
@@ -427,8 +427,8 @@ github.com/hashicorp/go-hclog v1.2.0 h1:La19f8d7WIlm4ogzNHB0JGqs5AUDAZ2UfCY4sJXc
 github.com/hashicorp/go-hclog v1.2.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
 github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
 github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
-github.com/hashicorp/go-retryablehttp v0.7.0 h1:eu1EI/mbirUgP5C8hVsTNaGZreBDlYiwC1FZWkvQPQ4=
-github.com/hashicorp/go-retryablehttp v0.7.0/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
+github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ=
+github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
@@ -549,8 +549,8 @@ github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbq
 github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
 github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
 github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU=
-github.com/lestrrat-go/jwx v1.2.22 h1:bCxMokwHNuJHVxgANP4OBddXGtQ9Oy+6cqp4O2rW7DU=
-github.com/lestrrat-go/jwx v1.2.22/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
+github.com/lestrrat-go/jwx v1.2.23 h1:8oP5fY1yzCRraUNNyfAVdOkLCqY7xMZz11lVcvHqC1Y=
+github.com/lestrrat-go/jwx v1.2.23/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
 github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4=
 github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@@ -627,6 +627,8 @@ github.com/otiai10/mint v1.3.3 h1:7JgpsBaN0uMkyju4tbYHu0mnM55hNKVYLsXmwr15NQI=
 github.com/otiai10/mint v1.3.3/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc=
 github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
 github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
+github.com/pelletier/go-toml/v2 v2.0.0-beta.8 h1:dy81yyLYJDwMTifq24Oi/IslOslRrDSb3jwDggjz3Z0=
+github.com/pelletier/go-toml/v2 v2.0.0-beta.8/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
 github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
 github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
@@ -708,8 +710,8 @@ github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmq
 github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
-github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk=
-github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU=
+github.com/spf13/viper v1.11.0 h1:7OX/1FS6n7jHD1zGrZTM7WtY13ZELRyosK4k93oPr44=
+github.com/spf13/viper v1.11.0/go.mod h1:djo0X/bA5+tYVoCn+C7cAYJGcVn/qYLFTG8gdUsX7Zk=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
@@ -934,8 +936,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220412071739-889880a91fd5 h1:NubxfvTRuNb4RVzWrIDAUzUvREH1HkCD4JjyQTSG9As=
-golang.org/x/sys v0.0.0-20220412071739-889880a91fd5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0=
+golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -1166,8 +1168,9 @@ google.golang.org/genproto v0.0.0-20220310185008-1973136f34c6/go.mod h1:kGP+zUP2
 google.golang.org/genproto v0.0.0-20220324131243-acbaeb5b85eb/go.mod h1:hAL49I2IFola2sVEjAn7MEwsja0xp51I0tlGAf9hz4E=
 google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
 google.golang.org/genproto v0.0.0-20220405205423-9d709892a2bf/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
-google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac h1:qSNTkEN+L2mvWcLgJOR+8bdHX9rN/IdU3A1Ghpfb1Rg=
 google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
+google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 h1:XGQ6tc+EnM35IAazg4y6AHmUg4oK8NXsXaILte1vRlk=
+google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
 google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=

+ 27 - 43
httpd/api_http_user.go

@@ -34,7 +34,7 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
 	connID := xid.New().String()
 	protocol := getProtocolFromRequest(r)
 	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return nil, err
 	}
@@ -43,6 +43,10 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
 			r.RemoteAddr, user),
 		request: r,
 	}
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return connection, err
+	}
 	return connection, nil
 }
 
@@ -52,7 +56,6 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -70,7 +73,6 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -90,22 +92,7 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
 
 func renameUserDir(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	connection, err := getUserConnection(w, r)
-	if err != nil {
-		return
-	}
-	common.Connections.Add(connection)
-	defer common.Connections.Remove(connection.GetID())
-
-	oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
-	newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
-	err = connection.Rename(oldName, newName)
-	if err != nil {
-		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename directory %#v to %#v", oldName, newName),
-			getMappedStatusCode(err))
-		return
-	}
-	sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %#v renamed to %#v", oldName, newName), http.StatusOK)
+	renameItem(w, r)
 }
 
 func deleteUserDir(w http.ResponseWriter, r *http.Request) {
@@ -114,7 +101,6 @@ func deleteUserDir(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -132,7 +118,6 @@ func getUserFile(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -183,7 +168,6 @@ func setFileDirMetadata(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -214,7 +198,6 @@ func uploadUserFile(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -258,6 +241,8 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
+	defer common.Connections.Remove(connection.GetID())
+
 	transferQuota := connection.GetTransferQuota()
 	if !transferQuota.HasUploadSpace() {
 		connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")
@@ -265,8 +250,6 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
 			http.StatusRequestEntityTooLarge)
 		return
 	}
-	common.Connections.Add(connection)
-	defer common.Connections.Remove(connection.GetID())
 
 	t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
 	r.Body = t
@@ -332,22 +315,7 @@ func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connectio
 
 func renameUserFile(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
-	connection, err := getUserConnection(w, r)
-	if err != nil {
-		return
-	}
-	common.Connections.Add(connection)
-	defer common.Connections.Remove(connection.GetID())
-
-	oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
-	newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
-	err = connection.Rename(oldName, newName)
-	if err != nil {
-		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename file %#v to %#v", oldName, newName),
-			getMappedStatusCode(err))
-		return
-	}
-	sendAPIResponse(w, r, nil, fmt.Sprintf("File %#v renamed to %#v", oldName, newName), http.StatusOK)
+	renameItem(w, r)
 }
 
 func deleteUserFile(w http.ResponseWriter, r *http.Request) {
@@ -356,7 +324,6 @@ func deleteUserFile(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -393,7 +360,6 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		return
 	}
-	common.Connections.Add(connection)
 	defer common.Connections.Remove(connection.GetID())
 
 	var filesList []string
@@ -581,3 +547,21 @@ func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath stri
 		}
 	}
 }
+
+func renameItem(w http.ResponseWriter, r *http.Request) {
+	connection, err := getUserConnection(w, r)
+	if err != nil {
+		return
+	}
+	defer common.Connections.Remove(connection.GetID())
+
+	oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
+	newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
+	err = connection.Rename(oldName, newName)
+	if err != nil {
+		sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %#v -> %#v", oldName, newName),
+			getMappedStatusCode(err))
+		return
+	}
+	sendAPIResponse(w, r, nil, fmt.Sprintf("%#v renamed to %#v", oldName, newName), http.StatusOK)
+}

+ 20 - 5
httpd/api_shares.go

@@ -167,7 +167,10 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.
 		return
 	}
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	contents, err := connection.ReadDir(name)
@@ -194,7 +197,10 @@ func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http
 		return
 	}
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	info, err := connection.Stat(name, 1)
@@ -231,7 +237,10 @@ func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	compress := true
@@ -289,7 +298,10 @@ func (s *httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request)
 	}
 	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 	if err := doUploadFile(w, r, connection, filePath); err != nil {
 		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
@@ -313,7 +325,10 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)

+ 2 - 2
httpd/api_utils.go

@@ -498,7 +498,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
 	dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
 }
 
-func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
+func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
 	if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) {
 		logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
 		return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
@@ -507,7 +507,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID
 		logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
 		return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
 	}
-	if user.MaxSessions > 0 {
+	if checkSessions && user.MaxSessions > 0 {
 		activeSessions := common.Connections.GetActiveSessions(user.Username)
 		if activeSessions >= user.MaxSessions {
 			logger.Info(logSender, connectionID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,

+ 173 - 12
httpd/httpd_test.go

@@ -3857,7 +3857,8 @@ func TestCloseActiveConnection(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	common.Connections.Add(fakeConn)
+	err = common.Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	_, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK)
 	assert.NoError(t, err)
 	assert.Len(t, common.Connections.GetStats(), 0)
@@ -3870,12 +3871,14 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	common.Connections.Add(fakeConn)
+	err = common.Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user)
 	fakeConn1 := &fakeConnection{
 		BaseConnection: c1,
 	}
-	common.Connections.Add(fakeConn1)
+	err = common.Connections.Add(fakeConn1)
+	assert.NoError(t, err)
 	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0")
 	assert.NoError(t, err)
 	assert.Len(t, common.Connections.GetStats(), 2)
@@ -3883,8 +3886,10 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Len(t, common.Connections.GetStats(), 0)
 
-	common.Connections.Add(fakeConn)
-	common.Connections.Add(fakeConn1)
+	err = common.Connections.Add(fakeConn)
+	assert.NoError(t, err)
+	err = common.Connections.Add(fakeConn1)
+	assert.NoError(t, err)
 	assert.Len(t, common.Connections.GetStats(), 2)
 	_, err = httpdtest.RemoveUser(user, http.StatusOK)
 	assert.NoError(t, err)
@@ -5173,7 +5178,8 @@ func TestLoaddataMode(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	common.Connections.Add(fakeConn)
+	err = common.Connections.Add(fakeConn)
+	assert.NoError(t, err)
 	assert.Len(t, common.Connections.GetStats(), 1)
 	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
 	assert.NoError(t, err)
@@ -8714,7 +8720,8 @@ func TestWebClientMaxConnections(t *testing.T) {
 	connection := &httpd.Connection{
 		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
 	}
-	common.Connections.Add(connection)
+	err = common.Connections.Add(connection)
+	assert.NoError(t, err)
 
 	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
 	assert.Error(t, err)
@@ -8895,20 +8902,57 @@ func TestMaxSessions(t *testing.T) {
 	u.Email = "user@session.com"
 	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
 	assert.NoError(t, err)
-	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
+	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
 	assert.NoError(t, err)
-	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
+	apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
 	assert.NoError(t, err)
 	// now add a fake connection
 	fs := vfs.NewOsFs("id", os.TempDir(), "")
 	connection := &httpd.Connection{
 		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
 	}
-	common.Connections.Add(connection)
+	err = common.Connections.Add(connection)
+	assert.NoError(t, err)
 	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
 	assert.Error(t, err)
 	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
 	assert.Error(t, err)
+	// try an user API call
+	req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
+	assert.NoError(t, err)
+	setBearerForReq(req, apiToken)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+	// web client requests
+	req, err = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil)
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, webToken)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, webToken)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil)
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, webToken)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil)
+	assert.NoError(t, err)
+	setJWTCookieForReq(req, webToken)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
 	// test reset password
 	smtpCfg := smtp.Config{
 		Host:          "127.0.0.1",
@@ -8924,11 +8968,11 @@ func TestMaxSessions(t *testing.T) {
 	form.Set(csrfFormToken, csrfToken)
 	form.Set("username", user.Username)
 	lastResetCode = ""
-	req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
+	req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
 	assert.NoError(t, err)
 	req.RemoteAddr = defaultRemoteAddr
 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-	rr := executeRequest(req)
+	rr = executeRequest(req)
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.GreaterOrEqual(t, len(lastResetCode), 20)
 	form = make(url.Values)
@@ -9630,6 +9674,123 @@ func TestShareUsage(t *testing.T) {
 	executeRequest(req)
 }
 
+func TestShareMaxSessions(t *testing.T) {
+	u := getTestUser()
+	u.MaxSessions = 1
+	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err)
+	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
+	assert.NoError(t, err)
+
+	share := dataprovider.Share{
+		Name:  "test share max sessions read",
+		Scope: dataprovider.ShareScopeRead,
+		Paths: []string{"/"},
+	}
+	asJSON, err := json.Marshal(share)
+	assert.NoError(t, err)
+	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
+	assert.NoError(t, err)
+	setBearerForReq(req, token)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusCreated, rr)
+	objectID := rr.Header().Get("X-Object-ID")
+	assert.NotEmpty(t, objectID)
+
+	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+	// add a fake connection
+	fs := vfs.NewOsFs("id", os.TempDir(), "")
+	connection := &httpd.Connection{
+		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
+	}
+	err = common.Connections.Add(connection)
+	assert.NoError(t, err)
+
+	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil)
+	assert.NoError(t, err)
+	setBearerForReq(req, token)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	// now test a write share
+	share = dataprovider.Share{
+		Name:  "test share max sessions write",
+		Scope: dataprovider.ShareScopeWrite,
+		Paths: []string{"/"},
+	}
+	asJSON, err = json.Marshal(share)
+	assert.NoError(t, err)
+	req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
+	assert.NoError(t, err)
+	setBearerForReq(req, token)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusCreated, rr)
+	objectID = rr.Header().Get("X-Object-ID")
+	assert.NotEmpty(t, objectID)
+
+	req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content")))
+	assert.NoError(t, err)
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	body := new(bytes.Buffer)
+	writer := multipart.NewWriter(body)
+	part1, err := writer.CreateFormFile("filenames", "file1.txt")
+	assert.NoError(t, err)
+	_, err = part1.Write([]byte("file1 content"))
+	assert.NoError(t, err)
+	err = writer.Close()
+	assert.NoError(t, err)
+	reader := bytes.NewReader(body.Bytes())
+	req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
+	assert.NoError(t, err)
+	req.Header.Add("Content-Type", writer.FormDataContentType())
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusTooManyRequests, rr)
+	assert.Contains(t, rr.Body.String(), "too many open sessions")
+
+	common.Connections.Remove(connection.GetID())
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+	assert.Len(t, common.Connections.GetStats(), 0)
+}
+
 func TestShareUploadSingle(t *testing.T) {
 	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
 	assert.NoError(t, err)

+ 1 - 1
httpd/middleware.go

@@ -434,7 +434,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
 		return err
 	}
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		return err
 	}

+ 1 - 1
httpd/oidc.go

@@ -331,7 +331,7 @@ func (t *oidcToken) getUser(r *http.Request) error {
 		return err
 	}
 	connectionID := fmt.Sprintf("%v_%v", common.ProtocolOIDC, xid.New().String())
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
 		return err
 	}

+ 4 - 4
httpd/server.go

@@ -228,7 +228,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 		return
 	}
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		s.renderClientLoginPage(w, err.Error(), ipAddr)
 		return
@@ -268,7 +268,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
 		return
 	}
 	connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
-	if err := checkHTTPClientUser(user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
 		s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
 		return
 	}
@@ -760,7 +760,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
 		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return
@@ -920,7 +920,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
 		logger.Debug(logSender, "", "signature mismatch for user %#v, unable to refresh cookie", user.Username)
 		return
 	}
-	if err := checkHTTPClientUser(&user, r, xid.New().String()); err != nil {
+	if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
 		logger.Debug(logSender, "", "unable to refresh cookie for user %#v: %v", user.Username, err)
 		return
 	}

+ 28 - 10
httpd/webclient.go

@@ -595,7 +595,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
 	connID := xid.New().String()
 	protocol := getProtocolFromRequest(r)
 	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
 		s.renderClientForbiddenPage(w, r, err.Error())
 		return
 	}
@@ -604,7 +604,10 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
 			r.RemoteAddr, user),
 		request: r,
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -635,7 +638,10 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R
 		s.renderClientMessagePage(w, r, "Invalid share path", "", getRespStatus(err), err, "")
 		return
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	contents, err := connection.ReadDir(name)
@@ -691,7 +697,10 @@ func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request
 		return
 	}
 
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	var info os.FileInfo
@@ -735,7 +744,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
 	connID := xid.New().String()
 	protocol := getProtocolFromRequest(r)
 	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
 		sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
 		return
 	}
@@ -744,7 +753,10 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
 			r.RemoteAddr, user),
 		request: r,
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -809,7 +821,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
 	connID := xid.New().String()
 	protocol := getProtocolFromRequest(r)
 	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
 		s.renderClientForbiddenPage(w, r, err.Error())
 		return
 	}
@@ -818,7 +830,10 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
 			r.RemoteAddr, user),
 		request: r,
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@@ -866,7 +881,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
 	connID := xid.New().String()
 	protocol := getProtocolFromRequest(r)
 	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
-	if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
+	if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
 		s.renderClientForbiddenPage(w, r, err.Error())
 		return
 	}
@@ -875,7 +890,10 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
 			r.RemoteAddr, user),
 		request: r,
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))

+ 2 - 2
sftpd/cryptfs_test.go

@@ -369,7 +369,7 @@ func TestTruncate(t *testing.T) {
 }
 
 func TestSCPBasicHandlingCryptoFs(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -427,7 +427,7 @@ func TestSCPBasicHandlingCryptoFs(t *testing.T) {
 }
 
 func TestSCPRecursiveCryptFs(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true

+ 5 - 1
sftpd/handler.go

@@ -457,8 +457,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
 	return t, nil
 }
 
-// Disconnect disconnects the client closing the network connection
+// Disconnect disconnects the client by closing the channel
 func (c *Connection) Disconnect() error {
+	if c.channel == nil {
+		c.Log(logger.LevelWarn, "cannot disconnect a nil channel")
+		return nil
+	}
 	return c.channel.Close()
 }
 

+ 43 - 0
sftpd/internal_test.go

@@ -14,6 +14,7 @@ import (
 
 	"github.com/eikenb/pipeat"
 	"github.com/pkg/sftp"
+	"github.com/rs/xid"
 	"github.com/sftpgo/sdk"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -2152,3 +2153,45 @@ func TestLoadRevokedUserCertsFile(t *testing.T) {
 	err = os.RemoveAll(r.filePath)
 	assert.NoError(t, err)
 }
+
+func TestMaxUserSessions(t *testing.T) {
+	connection := &Connection{
+		BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
+			BaseUser: sdk.BaseUser{
+				Username:    "user_max_sessions",
+				HomeDir:     filepath.Clean(os.TempDir()),
+				MaxSessions: 1,
+			},
+		}),
+	}
+	err := common.Connections.Add(connection)
+	assert.NoError(t, err)
+
+	c := Configuration{}
+	c.handleSftpConnection(nil, connection)
+
+	sshCmd := sshCommand{
+		command:    "cd",
+		connection: connection,
+	}
+	err = sshCmd.handle()
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "too many open sessions")
+	}
+	scpCmd := scpCommand{
+		sshCommand: sshCommand{
+			command:    "scp",
+			connection: connection,
+		},
+	}
+	err = scpCmd.handle()
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "too many open sessions")
+	}
+	err = ServeSubSystemConnection(&connection.User, connection.ID, nil, nil)
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "too many open sessions")
+	}
+	common.Connections.Remove(connection.GetID())
+	assert.Len(t, common.Connections.GetStats(), 0)
+}

+ 4 - 1
sftpd/scp.go

@@ -37,7 +37,10 @@ func (c *scpCommand) handle() (err error) {
 			err = common.ErrGenericFailure
 		}
 	}()
-	common.Connections.Add(c.connection)
+	if err := common.Connections.Add(c.connection); err != nil {
+		logger.Info(logSender, "", "unable to add SCP connection: %v", err)
+		return err
+	}
 	defer common.Connections.Remove(c.connection.GetID())
 
 	destPath := c.getDestPath()

+ 7 - 3
sftpd/server.go

@@ -562,7 +562,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 				case "subsystem":
 					if string(req.Payload[4:]) == "sftp" {
 						ok = true
-						connection := Connection{
+						connection := &Connection{
 							BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(),
 								conn.RemoteAddr().String(), user),
 							ClientVersion: string(sconn.ClientVersion()),
@@ -571,7 +571,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 							channel:       channel,
 							folderPrefix:  c.FolderPrefix,
 						}
-						go c.handleSftpConnection(channel, &connection)
+						go c.handleSftpConnection(channel, connection)
 					}
 				case "exec":
 					// protocol will be set later inside processSSHCommand it could be SSH or SCP
@@ -600,7 +600,11 @@ func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Co
 			logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
 		}
 	}()
-	common.Connections.Add(connection)
+	if err := common.Connections.Add(connection); err != nil {
+		errClose := connection.Disconnect()
+		logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
 	// Create the server instance for the channel using the handler we created above.

+ 32 - 20
sftpd/sftpd_test.go

@@ -3,6 +3,7 @@ package sftpd_test
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/rand"
 	"crypto/sha256"
 	"crypto/sha512"
@@ -129,6 +130,7 @@ var (
 	allPerms         = []string{dataprovider.PermAny}
 	homeBasePath     string
 	scpPath          string
+	scpForce         bool
 	gitPath          string
 	sshPath          string
 	hookCmdPath      string
@@ -935,8 +937,6 @@ func TestConcurrency(t *testing.T) {
 
 			conn, client, err := getSftpClient(user, usePubKey)
 			if assert.NoError(t, err) {
-				err = checkBasicSFTP(client)
-				assert.NoError(t, err)
 				err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
 				assert.NoError(t, err)
 				assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
@@ -9231,7 +9231,7 @@ func TestGitErrors(t *testing.T) {
 
 // Start SCP tests
 func TestSCPBasicHandling(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9295,7 +9295,7 @@ func TestSCPBasicHandling(t *testing.T) {
 }
 
 func TestSCPUploadFileOverwrite(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9376,7 +9376,7 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
 }
 
 func TestSCPRecursive(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9485,7 +9485,7 @@ func TestSCPStartDirectory(t *testing.T) {
 }
 
 func TestSCPPatternsFilter(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9606,7 +9606,7 @@ func TestSCPUploadMaxSize(t *testing.T) {
 }
 
 func TestSCPVirtualFolders(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9658,7 +9658,7 @@ func TestSCPVirtualFolders(t *testing.T) {
 }
 
 func TestSCPNestedFolders(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
@@ -9791,7 +9791,7 @@ func TestSCPNestedFolders(t *testing.T) {
 }
 
 func TestSCPVirtualFoldersQuota(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9889,7 +9889,7 @@ func TestSCPVirtualFoldersQuota(t *testing.T) {
 }
 
 func TestSCPPermsSubDirs(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9929,7 +9929,7 @@ func TestSCPPermsSubDirs(t *testing.T) {
 }
 
 func TestSCPPermCreateDirs(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9963,7 +9963,7 @@ func TestSCPPermCreateDirs(t *testing.T) {
 }
 
 func TestSCPPermUpload(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -9987,7 +9987,7 @@ func TestSCPPermUpload(t *testing.T) {
 }
 
 func TestSCPPermOverwrite(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10014,7 +10014,7 @@ func TestSCPPermOverwrite(t *testing.T) {
 }
 
 func TestSCPPermDownload(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10043,7 +10043,7 @@ func TestSCPPermDownload(t *testing.T) {
 }
 
 func TestSCPQuotaSize(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10099,7 +10099,7 @@ func TestSCPQuotaSize(t *testing.T) {
 }
 
 func TestSCPEscapeHomeDir(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10135,7 +10135,7 @@ func TestSCPEscapeHomeDir(t *testing.T) {
 }
 
 func TestSCPUploadPaths(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10170,7 +10170,7 @@ func TestSCPUploadPaths(t *testing.T) {
 }
 
 func TestSCPOverwriteDirWithFile(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	usePubKey := true
@@ -10194,7 +10194,7 @@ func TestSCPOverwriteDirWithFile(t *testing.T) {
 }
 
 func TestSCPRemoteToRemote(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	if runtime.GOOS == osWindows {
@@ -10229,7 +10229,7 @@ func TestSCPRemoteToRemote(t *testing.T) {
 }
 
 func TestSCPErrors(t *testing.T) {
-	if len(scpPath) == 0 {
+	if scpPath == "" {
 		t.Skip("scp command not found, unable to execute this test")
 	}
 	u := getTestUser(true)
@@ -10682,6 +10682,9 @@ func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive
 	if recursive {
 		args = append(args, "-r")
 	}
+	if scpForce {
+		args = append(args, "-O")
+	}
 	args = append(args, "-P")
 	args = append(args, "2022")
 	args = append(args, "-o")
@@ -10707,6 +10710,9 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem
 			args = append(args, "-r")
 		}
 	}
+	if scpForce {
+		args = append(args, "-O")
+	}
 	args = append(args, "-P")
 	args = append(args, "2022")
 	args = append(args, "-o")
@@ -10770,6 +10776,12 @@ func checkSystemCommands() {
 		logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err)
 		logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
 		scpPath = ""
+	} else {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		cmd := exec.CommandContext(ctx, scpPath, "-O")
+		out, _ := cmd.CombinedOutput()
+		scpForce = !strings.Contains(string(out), "option -- O")
 	}
 }
 

+ 5 - 2
sftpd/ssh_cmd.go

@@ -115,7 +115,10 @@ func (c *sshCommand) handle() (err error) {
 			err = common.ErrGenericFailure
 		}
 	}()
-	common.Connections.Add(c.connection)
+	if err := common.Connections.Add(c.connection); err != nil {
+		logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
+		return err
+	}
 	defer common.Connections.Remove(c.connection.GetID())
 
 	c.connection.UpdateLastActivity()
@@ -131,7 +134,7 @@ func (c *sshCommand) handle() (err error) {
 		c.sendExitStatus(nil)
 	} else if c.command == "pwd" {
 		// hard coded response to "/"
-		c.connection.channel.Write([]byte("/\n")) //nolint:errcheck
+		c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
 		c.sendExitStatus(nil)
 	} else if c.command == "sftpgo-copy" {
 		return c.handleSFTPGoCopy()

+ 7 - 2
sftpd/subsystem.go

@@ -43,7 +43,6 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
 		logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
 		return err
 	}
-	dataprovider.UpdateLastLogin(user)
 
 	connection := &Connection{
 		BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, "", "", *user),
@@ -52,9 +51,15 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
 		LocalAddr:      &net.IPAddr{},
 		channel:        newSubsystemChannel(reader, writer),
 	}
-	common.Connections.Add(connection)
+	err = common.Connections.Add(connection)
+	if err != nil {
+		errClose := user.CloseFs()
+		logger.Warn(logSender, connectionID, "unable to add connection: %v close fs error: %v", err, errClose)
+		return err
+	}
 	defer common.Connections.Remove(connection.GetID())
 
+	dataprovider.UpdateLastLogin(user)
 	server := sftp.NewRequestServer(connection.channel, sftp.Handlers{
 		FileGet:  connection,
 		FilePut:  connection,

+ 4 - 8
vfs/cryptfs.go

@@ -203,18 +203,14 @@ func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) {
 	if err != nil {
 		return nil, err
 	}
-	entries, err := f.ReadDir(-1)
+	list, err := f.Readdir(-1)
 	f.Close()
 	if err != nil {
 		return nil, err
 	}
-	result := make([]os.FileInfo, len(entries))
-	for idx, entry := range entries {
-		info, err := entry.Info()
-		if err != nil {
-			return nil, err
-		}
-		result[idx] = fs.ConvertFileInfo(info)
+	result := make([]os.FileInfo, 0, len(list))
+	for _, info := range list {
+		result = append(result, fs.ConvertFileInfo(info))
 	}
 	return result, nil
 }

+ 2 - 11
vfs/osfs.go

@@ -3,7 +3,6 @@ package vfs
 import (
 	"fmt"
 	"io"
-	"io/fs"
 	"net/http"
 	"os"
 	"path"
@@ -171,20 +170,12 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
 	if err != nil {
 		return nil, err
 	}
-	entries, err := f.ReadDir(-1)
+	list, err := f.Readdir(-1)
 	f.Close()
 	if err != nil {
 		return nil, err
 	}
-	result := make([]fs.FileInfo, len(entries))
-	for idx, entry := range entries {
-		info, err := entry.Info()
-		if err != nil {
-			return nil, err
-		}
-		result[idx] = info
-	}
-	return result, nil
+	return list, nil
 }
 
 // IsUploadResumeSupported returns true if resuming uploads is supported

+ 12 - 14
webdavd/server.go

@@ -196,19 +196,25 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	updateLoginMetrics(&user, ipAddr, loginMethod, err)
-
-	ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
-	ctx = context.WithValue(ctx, requestStartKey, time.Now())
-
 	connection := &Connection{
 		BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r),
 			r.RemoteAddr, user),
 		request: r,
 	}
-	common.Connections.Add(connection)
+	if err = common.Connections.Add(connection); err != nil {
+		errClose := user.CloseFs()
+		logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose)
+		updateLoginMetrics(&user, ipAddr, loginMethod, err)
+		http.Error(w, err.Error(), http.StatusTooManyRequests)
+		return
+	}
 	defer common.Connections.Remove(connection.GetID())
 
+	updateLoginMetrics(&user, ipAddr, loginMethod, err)
+
+	ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
+	ctx = context.WithValue(ctx, requestStartKey, time.Now())
+
 	dataprovider.UpdateLastLogin(&user)
 
 	if s.checkRequestMethod(ctx, r, connection) {
@@ -311,14 +317,6 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
 			user.Username, loginMethod)
 		return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
 	}
-	if user.MaxSessions > 0 {
-		activeSessions := common.Connections.GetActiveSessions(user.Username)
-		if activeSessions >= user.MaxSessions {
-			logger.Info(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v",
-				user.Username, activeSessions, user.MaxSessions)
-			return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
-		}
-	}
 	if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
 		logger.Info(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v",
 			user.Username, r.RemoteAddr)

+ 4 - 2
webdavd/webdavd_test.go

@@ -1167,7 +1167,8 @@ func TestMaxConnections(t *testing.T) {
 	connection := &webdavd.Connection{
 		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
 	}
-	common.Connections.Add(connection)
+	err = common.Connections.Add(connection)
+	assert.NoError(t, err)
 	assert.Error(t, checkBasicFunc(client))
 	common.Connections.Remove(connection.GetID())
 	_, err = httpdtest.RemoveUser(user, http.StatusOK)
@@ -1222,7 +1223,8 @@ func TestMaxSessions(t *testing.T) {
 	connection := &webdavd.Connection{
 		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
 	}
-	common.Connections.Add(connection)
+	err = common.Connections.Add(connection)
+	assert.NoError(t, err)
 	assert.Error(t, checkBasicFunc(client))
 	common.Connections.Remove(connection.GetID())
 	_, err = httpdtest.RemoveUser(user, http.StatusOK)