refactoring of user session counters

Fixes #792

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-04-14 19:07:41 +02:00
parent 5bc0f4f8af
commit 002a06629e
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
28 changed files with 542 additions and 199 deletions

View file

@ -98,6 +98,7 @@ func init() {
Connections.clients = clientsMap{
clients: make(map[string]int),
}
Connections.perUserConns = make(map[string]int)
}
// errors definitions
@ -345,6 +346,7 @@ type ActiveTransfer interface {
type ActiveConnection interface {
GetID() string
GetUsername() string
GetMaxSessions() int
GetLocalAddress() string
GetRemoteAddress() string
GetClientVersion() string
@ -733,6 +735,29 @@ type ActiveConnections struct {
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
perUserConns map[string]int
}
// internal method, must be called within a locked block
func (conns *ActiveConnections) addUserConnection(username string) {
if username == "" {
return
}
conns.perUserConns[username]++
}
// internal method, must be called within a locked block
func (conns *ActiveConnections) removeUserConnection(username string) {
if username == "" {
return
}
if val, ok := conns.perUserConns[username]; ok {
conns.perUserConns[username]--
if val > 1 {
return
}
delete(conns.perUserConns, username)
}
}
// GetActiveSessions returns the number of active sessions for the given username.
@ -741,24 +766,27 @@ func (conns *ActiveConnections) GetActiveSessions(username string) int {
conns.RLock()
defer conns.RUnlock()
numSessions := 0
for _, c := range conns.connections {
if c.GetUsername() == username {
numSessions++
}
}
return numSessions
return conns.perUserConns[username]
}
// Add adds a new connection to the active ones
func (conns *ActiveConnections) Add(c ActiveConnection) {
func (conns *ActiveConnections) Add(c ActiveConnection) error {
conns.Lock()
defer conns.Unlock()
if username := c.GetUsername(); username != "" {
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
if val := conns.perUserConns[username]; val >= maxSessions {
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
}
}
conns.addUserConnection(username)
}
conns.connections = append(conns.connections, c)
metric.UpdateActiveConnectionsSize(len(conns.connections))
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
return nil
}
// Swap replaces an existing connection with the given one.
@ -771,6 +799,16 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
for idx, conn := range conns.connections {
if conn.GetID() == c.GetID() {
conns.removeUserConnection(conn.GetUsername())
if username := c.GetUsername(); username != "" {
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
if val := conns.perUserConns[username]; val >= maxSessions {
conns.addUserConnection(conn.GetUsername())
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
}
}
conns.addUserConnection(username)
}
err := conn.CloseFS()
conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
@ -793,6 +831,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
conns.connections[idx] = conns.connections[lastIdx]
conns.connections[lastIdx] = nil
conns.connections = conns.connections[:lastIdx]
conns.removeUserConnection(conn.GetUsername())
metric.UpdateActiveConnectionsSize(lastIdx)
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)

View file

@ -370,6 +370,29 @@ func TestWhitelist(t *testing.T) {
Config = configCopy
}
func TestUserMaxSessions(t *testing.T) {
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
MaxSessions: 1,
},
})
fakeConn := &fakeConnection{
BaseConnection: c,
}
err := Connections.Add(fakeConn)
assert.NoError(t, err)
err = Connections.Add(fakeConn)
assert.Error(t, err)
err = Connections.Swap(fakeConn)
assert.NoError(t, err)
Connections.Remove(fakeConn.GetID())
Connections.Lock()
Connections.removeUserConnection(userTestUsername)
Connections.Unlock()
assert.Len(t, Connections.GetStats(), 0)
}
func TestMaxConnections(t *testing.T) {
oldValue := Config.MaxTotalConnections
perHost := Config.MaxPerHostConnections
@ -387,7 +410,8 @@ func TestMaxConnections(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
Connections.Add(fakeConn)
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
@ -466,14 +490,16 @@ func TestIdleConnections(t *testing.T) {
sshConn1.lastActivity = c.lastActivity
sshConn2.lastActivity = c.lastActivity
Connections.AddSSHConnection(sshConn1)
Connections.Add(fakeConn)
err = Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Equal(t, Connections.GetActiveSessions(username), 1)
c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user)
fakeConn = &fakeConnection{
BaseConnection: c,
}
Connections.AddSSHConnection(sshConn2)
Connections.Add(fakeConn)
err = Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Equal(t, Connections.GetActiveSessions(username), 2)
cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
@ -481,7 +507,8 @@ func TestIdleConnections(t *testing.T) {
fakeConn = &fakeConnection{
BaseConnection: cFTP,
}
Connections.Add(fakeConn)
err = Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Equal(t, Connections.GetActiveSessions(username), 2)
assert.Len(t, Connections.GetStats(), 3)
Connections.RLock()
@ -521,7 +548,8 @@ func TestCloseConnection(t *testing.T) {
BaseConnection: c,
}
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
Connections.Add(fakeConn)
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
res := Connections.Close(fakeConn.GetID())
assert.True(t, res)
@ -536,19 +564,34 @@ func TestSwapConnection(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
Connections.Add(fakeConn)
err := Connections.Add(fakeConn)
assert.NoError(t, err)
if assert.Len(t, Connections.GetStats(), 1) {
assert.Equal(t, "", Connections.GetStats()[0].Username)
}
c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
Username: userTestUsername,
MaxSessions: 1,
},
})
fakeConn = &fakeConnection{
BaseConnection: c,
}
err := Connections.Swap(fakeConn)
c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
})
fakeConn1 := &fakeConnection{
BaseConnection: c1,
}
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
err = Connections.Swap(fakeConn)
assert.Error(t, err)
Connections.Remove(fakeConn1.ID)
err = Connections.Swap(fakeConn)
assert.NoError(t, err)
if assert.Len(t, Connections.GetStats(), 1) {
assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
@ -600,9 +643,12 @@ func TestConnectionStatus(t *testing.T) {
command: "PROPFIND",
}
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
Connections.Add(fakeConn1)
Connections.Add(fakeConn2)
Connections.Add(fakeConn3)
err := Connections.Add(fakeConn1)
assert.NoError(t, err)
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
stats := Connections.GetStats()
assert.Len(t, stats, 3)
@ -628,7 +674,7 @@ func TestConnectionStatus(t *testing.T) {
}
}
err := t1.Close()
err = t1.Close()
assert.NoError(t, err)
err = t2.Close()
assert.NoError(t, err)

View file

@ -80,6 +80,11 @@ func (c *BaseConnection) GetUsername() string {
return c.User.Username
}
// GetMaxSessions returns the maximum number of concurrent sessions allowed
func (c *BaseConnection) GetMaxSessions() int {
return c.User.MaxSessions
}
// GetProtocol returns the protocol for the connection
func (c *BaseConnection) GetProtocol() string {
return c.protocol

View file

@ -62,7 +62,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived = 150
Connections.Add(fakeConn1)
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
@ -76,7 +77,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived = 50
transfer2.BytesReceived = 60
Connections.Add(fakeConn2)
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
connID3 := xid.New().String()
conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
@ -86,7 +88,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
Connections.Add(fakeConn3)
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
// the transfers are not overquota
Connections.checkTransfers()
@ -146,7 +149,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
100, 0, true, fsFolder, dataprovider.TransferQuota{})
Connections.Add(fakeConn4)
err = Connections.Add(fakeConn4)
assert.NoError(t, err)
connID5 := xid.New().String()
conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
fakeConn5 := &fakeConnection{
@ -156,7 +160,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
100, 0, true, fsFolder, dataprovider.TransferQuota{})
Connections.Add(fakeConn5)
err = Connections.Add(fakeConn5)
assert.NoError(t, err)
transfer4.BytesReceived = 50
transfer5.BytesReceived = 40
Connections.checkTransfers()
@ -245,7 +250,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer1.BytesReceived = 150
Connections.Add(fakeConn1)
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
@ -258,7 +264,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer2.BytesReceived = 150
Connections.Add(fakeConn2)
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
@ -294,7 +301,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer3.BytesSent = 150
Connections.Add(fakeConn3)
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
connID4 := xid.New().String()
conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
@ -304,7 +312,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer4.BytesSent = 150
Connections.Add(fakeConn4)
err = Connections.Add(fakeConn4)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort)

View file

@ -593,7 +593,8 @@ func TestClientVersion(t *testing.T) {
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
clientContext: mockCC,
}
common.Connections.Add(connection)
err := common.Connections.Add(connection)
assert.NoError(t, err)
stats := common.Connections.GetStats()
if assert.Len(t, stats, 1) {
assert.Equal(t, "mock version", stats[0].ClientVersion)

View file

@ -168,8 +168,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
cc.RemoteAddr().String(), user),
clientContext: cc,
}
common.Connections.Add(connection)
return s.initialMsg, nil
err = common.Connections.Add(connection)
return s.initialMsg, err
}
// ClientDisconnected is called when the user disconnects, even if he never authenticated
@ -367,9 +367,9 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
}
err = common.Connections.Swap(connection)
if err != nil {
err = user.CloseFs()
logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err)
return nil, common.ErrInternalFailure
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose)
return nil, err
}
return connection, nil
}

15
go.mod
View file

@ -12,9 +12,9 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.15.3
github.com/aws/aws-sdk-go-v2/credentials v1.11.2
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4
github.com/aws/aws-sdk-go-v2/service/sts v1.16.3
github.com/cockroachdb/cockroach-go/v2 v2.2.8
@ -32,10 +32,10 @@ require (
github.com/grandcat/zeroconf v1.0.0
github.com/hashicorp/go-hclog v1.2.0
github.com/hashicorp/go-plugin v1.4.3
github.com/hashicorp/go-retryablehttp v0.7.0
github.com/hashicorp/go-retryablehttp v0.7.1
github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
github.com/klauspost/compress v1.15.1
github.com/lestrrat-go/jwx v1.2.22
github.com/lestrrat-go/jwx v1.2.23
github.com/lib/pq v1.10.5
github.com/lithammer/shortuuid/v3 v3.0.7
github.com/mattn/go-sqlite3 v1.14.12
@ -54,7 +54,7 @@ require (
github.com/shirou/gopsutil/v3 v3.22.3
github.com/spf13/afero v1.8.2
github.com/spf13/cobra v1.4.0
github.com/spf13/viper v1.10.1
github.com/spf13/viper v1.11.0
github.com/stretchr/testify v1.7.1
github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62
github.com/unrolled/secure v1.10.0
@ -67,7 +67,7 @@ require (
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/net v0.0.0-20220412020605-290c469a71a5
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad
golang.org/x/time v0.0.0-20220411224347-583f2d630306
google.golang.org/api v0.74.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
@ -130,6 +130,7 @@ require (
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/oklog/run v1.1.0 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20220216144756-c35f1ee13d7c // indirect
@ -151,7 +152,7 @@ require (
golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac // indirect
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 // indirect
google.golang.org/grpc v1.45.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect

29
go.sum
View file

@ -142,8 +142,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.11.2/go.mod h1:j8YsY9TXTm31k4eFhspiQ
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 h1:LWPg5zjHV9oz/myQr4wMs0gi4CjnDN/ILmyZUFYXZsU=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4 h1:iqcMQBj/B3FPxVb5SGNHC8XAh64hmaWUC8piZArBE7U=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4/go.mod h1:s79ZPBpDzcR1BCuAhGCF1rgmd/QmLueKCvdkmX4SDgg=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5 h1:lPo/NX1o4vkk2C7mHmB2FCf9Qp7KZNHrlzHxdP/yugw=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5/go.mod h1:JNo9mEKrjnmDBc19z65TZmj1xG9PQHu2GOlApYk31DU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 h1:onz/VaaxZ7Z4V+WIN9Txly9XLTmoOh1oJ8XcAC3pako=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 h1:9stUQR/u2KXU6HkFJYlqnZEjBnbgrVbG6I5HN09xZh0=
@ -164,8 +164,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYY
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3 h1:xqXHk4UDW7ii4MRciyLpY87yuZds0iymmgHt3h35xTE=
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3/go.mod h1:HT0cm2+NUCF33MdXjck554HC6VRgQ4q6JIlSqlYZ18Y=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4 h1:frOI/v6KWuKGlKUA5gheRw01EDpxcCxTalFQkCOZXAo=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5 h1:A3PuAUlh1u47WHcM68CDaG9ZWjK7ewePjDp+0dY9yv4=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4 h1:EmIEXOjAdXtxa2OGM1VAajZV/i06Q8qd4kBpJd9/p1k=
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o=
github.com/aws/aws-sdk-go-v2/service/sns v1.17.4/go.mod h1:kElt+uCcXxcqFyc+bQqZPFD9DME/eC6oHBXvFzQ9Bcw=
@ -427,8 +427,8 @@ github.com/hashicorp/go-hclog v1.2.0 h1:La19f8d7WIlm4ogzNHB0JGqs5AUDAZ2UfCY4sJXc
github.com/hashicorp/go-hclog v1.2.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
github.com/hashicorp/go-retryablehttp v0.7.0 h1:eu1EI/mbirUgP5C8hVsTNaGZreBDlYiwC1FZWkvQPQ4=
github.com/hashicorp/go-retryablehttp v0.7.0/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ=
github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
@ -549,8 +549,8 @@ github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbq
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU=
github.com/lestrrat-go/jwx v1.2.22 h1:bCxMokwHNuJHVxgANP4OBddXGtQ9Oy+6cqp4O2rW7DU=
github.com/lestrrat-go/jwx v1.2.22/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
github.com/lestrrat-go/jwx v1.2.23 h1:8oP5fY1yzCRraUNNyfAVdOkLCqY7xMZz11lVcvHqC1Y=
github.com/lestrrat-go/jwx v1.2.23/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@ -627,6 +627,8 @@ github.com/otiai10/mint v1.3.3 h1:7JgpsBaN0uMkyju4tbYHu0mnM55hNKVYLsXmwr15NQI=
github.com/otiai10/mint v1.3.3/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc=
github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pelletier/go-toml/v2 v2.0.0-beta.8 h1:dy81yyLYJDwMTifq24Oi/IslOslRrDSb3jwDggjz3Z0=
github.com/pelletier/go-toml/v2 v2.0.0-beta.8/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
@ -708,8 +710,8 @@ github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmq
github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk=
github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU=
github.com/spf13/viper v1.11.0 h1:7OX/1FS6n7jHD1zGrZTM7WtY13ZELRyosK4k93oPr44=
github.com/spf13/viper v1.11.0/go.mod h1:djo0X/bA5+tYVoCn+C7cAYJGcVn/qYLFTG8gdUsX7Zk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
@ -934,8 +936,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5 h1:NubxfvTRuNb4RVzWrIDAUzUvREH1HkCD4JjyQTSG9As=
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -1166,8 +1168,9 @@ google.golang.org/genproto v0.0.0-20220310185008-1973136f34c6/go.mod h1:kGP+zUP2
google.golang.org/genproto v0.0.0-20220324131243-acbaeb5b85eb/go.mod h1:hAL49I2IFola2sVEjAn7MEwsja0xp51I0tlGAf9hz4E=
google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
google.golang.org/genproto v0.0.0-20220405205423-9d709892a2bf/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac h1:qSNTkEN+L2mvWcLgJOR+8bdHX9rN/IdU3A1Ghpfb1Rg=
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 h1:XGQ6tc+EnM35IAazg4y6AHmUg4oK8NXsXaILte1vRlk=
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=

View file

@ -34,7 +34,7 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return nil, err
}
@ -43,6 +43,10 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
r.RemoteAddr, user),
request: r,
}
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return connection, err
}
return connection, nil
}
@ -52,7 +56,6 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -70,7 +73,6 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -90,22 +92,7 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
func renameUserDir(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
connection, err := getUserConnection(w, r)
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
err = connection.Rename(oldName, newName)
if err != nil {
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename directory %#v to %#v", oldName, newName),
getMappedStatusCode(err))
return
}
sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %#v renamed to %#v", oldName, newName), http.StatusOK)
renameItem(w, r)
}
func deleteUserDir(w http.ResponseWriter, r *http.Request) {
@ -114,7 +101,6 @@ func deleteUserDir(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -132,7 +118,6 @@ func getUserFile(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -183,7 +168,6 @@ func setFileDirMetadata(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -214,7 +198,6 @@ func uploadUserFile(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -258,6 +241,8 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
defer common.Connections.Remove(connection.GetID())
transferQuota := connection.GetTransferQuota()
if !transferQuota.HasUploadSpace() {
connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")
@ -265,8 +250,6 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
http.StatusRequestEntityTooLarge)
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
r.Body = t
@ -332,22 +315,7 @@ func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connectio
func renameUserFile(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
connection, err := getUserConnection(w, r)
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
err = connection.Rename(oldName, newName)
if err != nil {
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename file %#v to %#v", oldName, newName),
getMappedStatusCode(err))
return
}
sendAPIResponse(w, r, nil, fmt.Sprintf("File %#v renamed to %#v", oldName, newName), http.StatusOK)
renameItem(w, r)
}
func deleteUserFile(w http.ResponseWriter, r *http.Request) {
@ -356,7 +324,6 @@ func deleteUserFile(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -393,7 +360,6 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) {
if err != nil {
return
}
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())
var filesList []string
@ -581,3 +547,21 @@ func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath stri
}
}
}
func renameItem(w http.ResponseWriter, r *http.Request) {
connection, err := getUserConnection(w, r)
if err != nil {
return
}
defer common.Connections.Remove(connection.GetID())
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
err = connection.Rename(oldName, newName)
if err != nil {
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %#v -> %#v", oldName, newName),
getMappedStatusCode(err))
return
}
sendAPIResponse(w, r, nil, fmt.Sprintf("%#v renamed to %#v", oldName, newName), http.StatusOK)
}

View file

@ -167,7 +167,10 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name)
@ -194,7 +197,10 @@ func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
info, err := connection.Stat(name, 1)
@ -231,7 +237,10 @@ func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request)
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
compress := true
@ -289,7 +298,10 @@ func (s *httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request)
}
dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
if err := doUploadFile(w, r, connection, filePath); err != nil {
dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
@ -313,7 +325,10 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)

View file

@ -498,7 +498,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
}
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) {
logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
@ -507,7 +507,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID
logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
}
if user.MaxSessions > 0 {
if checkSessions && user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
if activeSessions >= user.MaxSessions {
logger.Info(logSender, connectionID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,

View file

@ -3857,7 +3857,8 @@ func TestCloseActiveConnection(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
common.Connections.Add(fakeConn)
err = common.Connections.Add(fakeConn)
assert.NoError(t, err)
_, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK)
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
@ -3870,12 +3871,14 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
common.Connections.Add(fakeConn)
err = common.Connections.Add(fakeConn)
assert.NoError(t, err)
c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user)
fakeConn1 := &fakeConnection{
BaseConnection: c1,
}
common.Connections.Add(fakeConn1)
err = common.Connections.Add(fakeConn1)
assert.NoError(t, err)
user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0")
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 2)
@ -3883,8 +3886,10 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
common.Connections.Add(fakeConn)
common.Connections.Add(fakeConn1)
err = common.Connections.Add(fakeConn)
assert.NoError(t, err)
err = common.Connections.Add(fakeConn1)
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 2)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
@ -5173,7 +5178,8 @@ func TestLoaddataMode(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
common.Connections.Add(fakeConn)
err = common.Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 1)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
@ -8714,7 +8720,8 @@ func TestWebClientMaxConnections(t *testing.T) {
connection := &httpd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
}
common.Connections.Add(connection)
err = common.Connections.Add(connection)
assert.NoError(t, err)
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
assert.Error(t, err)
@ -8895,20 +8902,57 @@ func TestMaxSessions(t *testing.T) {
u.Email = "user@session.com"
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
assert.NoError(t, err)
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
assert.NoError(t, err)
// now add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), "")
connection := &httpd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
}
common.Connections.Add(connection)
err = common.Connections.Add(connection)
assert.NoError(t, err)
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
assert.Error(t, err)
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
assert.Error(t, err)
// try an user API call
req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
assert.NoError(t, err)
setBearerForReq(req, apiToken)
rr := executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
// web client requests
req, err = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil)
assert.NoError(t, err)
setJWTCookieForReq(req, webToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
assert.NoError(t, err)
setJWTCookieForReq(req, webToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil)
assert.NoError(t, err)
setJWTCookieForReq(req, webToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil)
assert.NoError(t, err)
setJWTCookieForReq(req, webToken)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
// test reset password
smtpCfg := smtp.Config{
Host: "127.0.0.1",
@ -8924,11 +8968,11 @@ func TestMaxSessions(t *testing.T) {
form.Set(csrfFormToken, csrfToken)
form.Set("username", user.Username)
lastResetCode = ""
req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = defaultRemoteAddr
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rr := executeRequest(req)
rr = executeRequest(req)
assert.Equal(t, http.StatusFound, rr.Code)
assert.GreaterOrEqual(t, len(lastResetCode), 20)
form = make(url.Values)
@ -9630,6 +9674,123 @@ func TestShareUsage(t *testing.T) {
executeRequest(req)
}
func TestShareMaxSessions(t *testing.T) {
u := getTestUser()
u.MaxSessions = 1
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
assert.NoError(t, err)
share := dataprovider.Share{
Name: "test share max sessions read",
Scope: dataprovider.ShareScopeRead,
Paths: []string{"/"},
}
asJSON, err := json.Marshal(share)
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
assert.NoError(t, err)
setBearerForReq(req, token)
rr := executeRequest(req)
checkResponseCode(t, http.StatusCreated, rr)
objectID := rr.Header().Get("X-Object-ID")
assert.NotEmpty(t, objectID)
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusOK, rr)
// add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), "")
connection := &httpd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
}
err = common.Connections.Add(connection)
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil)
assert.NoError(t, err)
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusOK, rr)
// now test a write share
share = dataprovider.Share{
Name: "test share max sessions write",
Scope: dataprovider.ShareScopeWrite,
Paths: []string{"/"},
}
asJSON, err = json.Marshal(share)
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
assert.NoError(t, err)
setBearerForReq(req, token)
rr = executeRequest(req)
checkResponseCode(t, http.StatusCreated, rr)
objectID = rr.Header().Get("X-Object-ID")
assert.NotEmpty(t, objectID)
req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content")))
assert.NoError(t, err)
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part1, err := writer.CreateFormFile("filenames", "file1.txt")
assert.NoError(t, err)
_, err = part1.Write([]byte("file1 content"))
assert.NoError(t, err)
err = writer.Close()
assert.NoError(t, err)
reader := bytes.NewReader(body.Bytes())
req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
assert.NoError(t, err)
req.Header.Add("Content-Type", writer.FormDataContentType())
rr = executeRequest(req)
checkResponseCode(t, http.StatusTooManyRequests, rr)
assert.Contains(t, rr.Body.String(), "too many open sessions")
common.Connections.Remove(connection.GetID())
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
}
func TestShareUploadSingle(t *testing.T) {
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)

View file

@ -434,7 +434,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
return err
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
return err
}

View file

@ -331,7 +331,7 @@ func (t *oidcToken) getUser(r *http.Request) error {
return err
}
connectionID := fmt.Sprintf("%v_%v", common.ProtocolOIDC, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
return err
}

View file

@ -228,7 +228,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
return
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, err.Error(), ipAddr)
return
@ -268,7 +268,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
return
}
connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID); err != nil {
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
return
}
@ -760,7 +760,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
return
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
@ -920,7 +920,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
logger.Debug(logSender, "", "signature mismatch for user %#v, unable to refresh cookie", user.Username)
return
}
if err := checkHTTPClientUser(&user, r, xid.New().String()); err != nil {
if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
logger.Debug(logSender, "", "unable to refresh cookie for user %#v: %v", user.Username, err)
return
}

View file

@ -595,7 +595,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
s.renderClientForbiddenPage(w, r, err.Error())
return
}
@ -604,7 +604,10 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
r.RemoteAddr, user),
request: r,
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -635,7 +638,10 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R
s.renderClientMessagePage(w, r, "Invalid share path", "", getRespStatus(err), err, "")
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
contents, err := connection.ReadDir(name)
@ -691,7 +697,10 @@ func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request
return
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
var info os.FileInfo
@ -735,7 +744,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
@ -744,7 +753,10 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
r.RemoteAddr, user),
request: r,
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -809,7 +821,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
s.renderClientForbiddenPage(w, r, err.Error())
return
}
@ -818,7 +830,10 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
r.RemoteAddr, user),
request: r,
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
@ -866,7 +881,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
s.renderClientForbiddenPage(w, r, err.Error())
return
}
@ -875,7 +890,10 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
r.RemoteAddr, user),
request: r,
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
return
}
defer common.Connections.Remove(connection.GetID())
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))

View file

@ -369,7 +369,7 @@ func TestTruncate(t *testing.T) {
}
func TestSCPBasicHandlingCryptoFs(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -427,7 +427,7 @@ func TestSCPBasicHandlingCryptoFs(t *testing.T) {
}
func TestSCPRecursiveCryptFs(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true

View file

@ -457,8 +457,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
return t, nil
}
// Disconnect disconnects the client closing the network connection
// Disconnect disconnects the client by closing the channel
func (c *Connection) Disconnect() error {
if c.channel == nil {
c.Log(logger.LevelWarn, "cannot disconnect a nil channel")
return nil
}
return c.channel.Close()
}

View file

@ -14,6 +14,7 @@ import (
"github.com/eikenb/pipeat"
"github.com/pkg/sftp"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -2152,3 +2153,45 @@ func TestLoadRevokedUserCertsFile(t *testing.T) {
err = os.RemoveAll(r.filePath)
assert.NoError(t, err)
}
func TestMaxUserSessions(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: "user_max_sessions",
HomeDir: filepath.Clean(os.TempDir()),
MaxSessions: 1,
},
}),
}
err := common.Connections.Add(connection)
assert.NoError(t, err)
c := Configuration{}
c.handleSftpConnection(nil, connection)
sshCmd := sshCommand{
command: "cd",
connection: connection,
}
err = sshCmd.handle()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "too many open sessions")
}
scpCmd := scpCommand{
sshCommand: sshCommand{
command: "scp",
connection: connection,
},
}
err = scpCmd.handle()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "too many open sessions")
}
err = ServeSubSystemConnection(&connection.User, connection.ID, nil, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "too many open sessions")
}
common.Connections.Remove(connection.GetID())
assert.Len(t, common.Connections.GetStats(), 0)
}

View file

@ -37,7 +37,10 @@ func (c *scpCommand) handle() (err error) {
err = common.ErrGenericFailure
}
}()
common.Connections.Add(c.connection)
if err := common.Connections.Add(c.connection); err != nil {
logger.Info(logSender, "", "unable to add SCP connection: %v", err)
return err
}
defer common.Connections.Remove(c.connection.GetID())
destPath := c.getDestPath()

View file

@ -562,7 +562,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
ok = true
connection := Connection{
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(),
conn.RemoteAddr().String(), user),
ClientVersion: string(sconn.ClientVersion()),
@ -571,7 +571,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
channel: channel,
folderPrefix: c.FolderPrefix,
}
go c.handleSftpConnection(channel, &connection)
go c.handleSftpConnection(channel, connection)
}
case "exec":
// protocol will be set later inside processSSHCommand it could be SSH or SCP
@ -600,7 +600,11 @@ func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Co
logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
}
}()
common.Connections.Add(connection)
if err := common.Connections.Add(connection); err != nil {
errClose := connection.Disconnect()
logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose)
return
}
defer common.Connections.Remove(connection.GetID())
// Create the server instance for the channel using the handler we created above.

View file

@ -3,6 +3,7 @@ package sftpd_test
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
@ -129,6 +130,7 @@ var (
allPerms = []string{dataprovider.PermAny}
homeBasePath string
scpPath string
scpForce bool
gitPath string
sshPath string
hookCmdPath string
@ -935,8 +937,6 @@ func TestConcurrency(t *testing.T) {
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
err = checkBasicSFTP(client)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
assert.NoError(t, err)
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
@ -9231,7 +9231,7 @@ func TestGitErrors(t *testing.T) {
// Start SCP tests
func TestSCPBasicHandling(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9295,7 +9295,7 @@ func TestSCPBasicHandling(t *testing.T) {
}
func TestSCPUploadFileOverwrite(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9376,7 +9376,7 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
}
func TestSCPRecursive(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9485,7 +9485,7 @@ func TestSCPStartDirectory(t *testing.T) {
}
func TestSCPPatternsFilter(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9606,7 +9606,7 @@ func TestSCPUploadMaxSize(t *testing.T) {
}
func TestSCPVirtualFolders(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9658,7 +9658,7 @@ func TestSCPVirtualFolders(t *testing.T) {
}
func TestSCPNestedFolders(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
@ -9791,7 +9791,7 @@ func TestSCPNestedFolders(t *testing.T) {
}
func TestSCPVirtualFoldersQuota(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9889,7 +9889,7 @@ func TestSCPVirtualFoldersQuota(t *testing.T) {
}
func TestSCPPermsSubDirs(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9929,7 +9929,7 @@ func TestSCPPermsSubDirs(t *testing.T) {
}
func TestSCPPermCreateDirs(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9963,7 +9963,7 @@ func TestSCPPermCreateDirs(t *testing.T) {
}
func TestSCPPermUpload(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -9987,7 +9987,7 @@ func TestSCPPermUpload(t *testing.T) {
}
func TestSCPPermOverwrite(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10014,7 +10014,7 @@ func TestSCPPermOverwrite(t *testing.T) {
}
func TestSCPPermDownload(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10043,7 +10043,7 @@ func TestSCPPermDownload(t *testing.T) {
}
func TestSCPQuotaSize(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10099,7 +10099,7 @@ func TestSCPQuotaSize(t *testing.T) {
}
func TestSCPEscapeHomeDir(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10135,7 +10135,7 @@ func TestSCPEscapeHomeDir(t *testing.T) {
}
func TestSCPUploadPaths(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10170,7 +10170,7 @@ func TestSCPUploadPaths(t *testing.T) {
}
func TestSCPOverwriteDirWithFile(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
@ -10194,7 +10194,7 @@ func TestSCPOverwriteDirWithFile(t *testing.T) {
}
func TestSCPRemoteToRemote(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
if runtime.GOOS == osWindows {
@ -10229,7 +10229,7 @@ func TestSCPRemoteToRemote(t *testing.T) {
}
func TestSCPErrors(t *testing.T) {
if len(scpPath) == 0 {
if scpPath == "" {
t.Skip("scp command not found, unable to execute this test")
}
u := getTestUser(true)
@ -10682,6 +10682,9 @@ func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive
if recursive {
args = append(args, "-r")
}
if scpForce {
args = append(args, "-O")
}
args = append(args, "-P")
args = append(args, "2022")
args = append(args, "-o")
@ -10707,6 +10710,9 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem
args = append(args, "-r")
}
}
if scpForce {
args = append(args, "-O")
}
args = append(args, "-P")
args = append(args, "2022")
args = append(args, "-o")
@ -10770,6 +10776,12 @@ func checkSystemCommands() {
logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err)
logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
scpPath = ""
} else {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, scpPath, "-O")
out, _ := cmd.CombinedOutput()
scpForce = !strings.Contains(string(out), "option -- O")
}
}

View file

@ -115,7 +115,10 @@ func (c *sshCommand) handle() (err error) {
err = common.ErrGenericFailure
}
}()
common.Connections.Add(c.connection)
if err := common.Connections.Add(c.connection); err != nil {
logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
return err
}
defer common.Connections.Remove(c.connection.GetID())
c.connection.UpdateLastActivity()
@ -131,7 +134,7 @@ func (c *sshCommand) handle() (err error) {
c.sendExitStatus(nil)
} else if c.command == "pwd" {
// hard coded response to "/"
c.connection.channel.Write([]byte("/\n")) //nolint:errcheck
c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
c.sendExitStatus(nil)
} else if c.command == "sftpgo-copy" {
return c.handleSFTPGoCopy()

View file

@ -43,7 +43,6 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
return err
}
dataprovider.UpdateLastLogin(user)
connection := &Connection{
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, "", "", *user),
@ -52,9 +51,15 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
LocalAddr: &net.IPAddr{},
channel: newSubsystemChannel(reader, writer),
}
common.Connections.Add(connection)
err = common.Connections.Add(connection)
if err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to add connection: %v close fs error: %v", err, errClose)
return err
}
defer common.Connections.Remove(connection.GetID())
dataprovider.UpdateLastLogin(user)
server := sftp.NewRequestServer(connection.channel, sftp.Handlers{
FileGet: connection,
FilePut: connection,

View file

@ -203,18 +203,14 @@ func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) {
if err != nil {
return nil, err
}
entries, err := f.ReadDir(-1)
list, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, err
}
result := make([]os.FileInfo, len(entries))
for idx, entry := range entries {
info, err := entry.Info()
if err != nil {
return nil, err
}
result[idx] = fs.ConvertFileInfo(info)
result := make([]os.FileInfo, 0, len(list))
for _, info := range list {
result = append(result, fs.ConvertFileInfo(info))
}
return result, nil
}

View file

@ -3,7 +3,6 @@ package vfs
import (
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path"
@ -171,20 +170,12 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
if err != nil {
return nil, err
}
entries, err := f.ReadDir(-1)
list, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, err
}
result := make([]fs.FileInfo, len(entries))
for idx, entry := range entries {
info, err := entry.Info()
if err != nil {
return nil, err
}
result[idx] = info
}
return result, nil
return list, nil
}
// IsUploadResumeSupported returns true if resuming uploads is supported

View file

@ -196,19 +196,25 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
updateLoginMetrics(&user, ipAddr, loginMethod, err)
ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
ctx = context.WithValue(ctx, requestStartKey, time.Now())
connection := &Connection{
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r),
r.RemoteAddr, user),
request: r,
}
common.Connections.Add(connection)
if err = common.Connections.Add(connection); err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose)
updateLoginMetrics(&user, ipAddr, loginMethod, err)
http.Error(w, err.Error(), http.StatusTooManyRequests)
return
}
defer common.Connections.Remove(connection.GetID())
updateLoginMetrics(&user, ipAddr, loginMethod, err)
ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
ctx = context.WithValue(ctx, requestStartKey, time.Now())
dataprovider.UpdateLastLogin(&user)
if s.checkRequestMethod(ctx, r, connection) {
@ -311,14 +317,6 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
user.Username, loginMethod)
return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
}
if user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
if activeSessions >= user.MaxSessions {
logger.Info(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v",
user.Username, activeSessions, user.MaxSessions)
return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
}
}
if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
logger.Info(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v",
user.Username, r.RemoteAddr)

View file

@ -1167,7 +1167,8 @@ func TestMaxConnections(t *testing.T) {
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
}
common.Connections.Add(connection)
err = common.Connections.Add(connection)
assert.NoError(t, err)
assert.Error(t, checkBasicFunc(client))
common.Connections.Remove(connection.GetID())
_, err = httpdtest.RemoveUser(user, http.StatusOK)
@ -1222,7 +1223,8 @@ func TestMaxSessions(t *testing.T) {
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
}
common.Connections.Add(connection)
err = common.Connections.Add(connection)
assert.NoError(t, err)
assert.Error(t, checkBasicFunc(client))
common.Connections.Remove(connection.GetID())
_, err = httpdtest.RemoveUser(user, http.StatusOK)