refactoring of user session counters
Fixes #792 Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
5bc0f4f8af
commit
002a06629e
28 changed files with 542 additions and 199 deletions
|
@ -98,6 +98,7 @@ func init() {
|
|||
Connections.clients = clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
Connections.perUserConns = make(map[string]int)
|
||||
}
|
||||
|
||||
// errors definitions
|
||||
|
@ -345,6 +346,7 @@ type ActiveTransfer interface {
|
|||
type ActiveConnection interface {
|
||||
GetID() string
|
||||
GetUsername() string
|
||||
GetMaxSessions() int
|
||||
GetLocalAddress() string
|
||||
GetRemoteAddress() string
|
||||
GetClientVersion() string
|
||||
|
@ -733,6 +735,29 @@ type ActiveConnections struct {
|
|||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
sshConnections []*SSHConnection
|
||||
perUserConns map[string]int
|
||||
}
|
||||
|
||||
// internal method, must be called within a locked block
|
||||
func (conns *ActiveConnections) addUserConnection(username string) {
|
||||
if username == "" {
|
||||
return
|
||||
}
|
||||
conns.perUserConns[username]++
|
||||
}
|
||||
|
||||
// internal method, must be called within a locked block
|
||||
func (conns *ActiveConnections) removeUserConnection(username string) {
|
||||
if username == "" {
|
||||
return
|
||||
}
|
||||
if val, ok := conns.perUserConns[username]; ok {
|
||||
conns.perUserConns[username]--
|
||||
if val > 1 {
|
||||
return
|
||||
}
|
||||
delete(conns.perUserConns, username)
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveSessions returns the number of active sessions for the given username.
|
||||
|
@ -741,24 +766,27 @@ func (conns *ActiveConnections) GetActiveSessions(username string) int {
|
|||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
||||
numSessions := 0
|
||||
for _, c := range conns.connections {
|
||||
if c.GetUsername() == username {
|
||||
numSessions++
|
||||
}
|
||||
}
|
||||
return numSessions
|
||||
return conns.perUserConns[username]
|
||||
}
|
||||
|
||||
// Add adds a new connection to the active ones
|
||||
func (conns *ActiveConnections) Add(c ActiveConnection) {
|
||||
func (conns *ActiveConnections) Add(c ActiveConnection) error {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
conns.connections = append(conns.connections, c)
|
||||
metric.UpdateActiveConnectionsSize(len(conns.connections))
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
|
||||
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Swap replaces an existing connection with the given one.
|
||||
|
@ -771,6 +799,16 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
|
|||
|
||||
for idx, conn := range conns.connections {
|
||||
if conn.GetID() == c.GetID() {
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
conns.addUserConnection(conn.GetUsername())
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
err := conn.CloseFS()
|
||||
conns.connections[idx] = c
|
||||
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
|
||||
|
@ -793,6 +831,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
|||
conns.connections[idx] = conns.connections[lastIdx]
|
||||
conns.connections[lastIdx] = nil
|
||||
conns.connections = conns.connections[:lastIdx]
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
metric.UpdateActiveConnectionsSize(lastIdx)
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
|
||||
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
|
||||
|
|
|
@ -370,6 +370,29 @@ func TestWhitelist(t *testing.T) {
|
|||
Config = configCopy
|
||||
}
|
||||
|
||||
func TestUserMaxSessions(t *testing.T) {
|
||||
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: userTestUsername,
|
||||
MaxSessions: 1,
|
||||
},
|
||||
})
|
||||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
err := Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
err = Connections.Add(fakeConn)
|
||||
assert.Error(t, err)
|
||||
err = Connections.Swap(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
Connections.Remove(fakeConn.GetID())
|
||||
Connections.Lock()
|
||||
Connections.removeUserConnection(userTestUsername)
|
||||
Connections.Unlock()
|
||||
assert.Len(t, Connections.GetStats(), 0)
|
||||
}
|
||||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := Config.MaxTotalConnections
|
||||
perHost := Config.MaxPerHostConnections
|
||||
|
@ -387,7 +410,8 @@ func TestMaxConnections(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
Connections.Add(fakeConn)
|
||||
err := Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
|
@ -466,14 +490,16 @@ func TestIdleConnections(t *testing.T) {
|
|||
sshConn1.lastActivity = c.lastActivity
|
||||
sshConn2.lastActivity = c.lastActivity
|
||||
Connections.AddSSHConnection(sshConn1)
|
||||
Connections.Add(fakeConn)
|
||||
err = Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Connections.GetActiveSessions(username), 1)
|
||||
c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user)
|
||||
fakeConn = &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
Connections.AddSSHConnection(sshConn2)
|
||||
Connections.Add(fakeConn)
|
||||
err = Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Connections.GetActiveSessions(username), 2)
|
||||
|
||||
cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
|
||||
|
@ -481,7 +507,8 @@ func TestIdleConnections(t *testing.T) {
|
|||
fakeConn = &fakeConnection{
|
||||
BaseConnection: cFTP,
|
||||
}
|
||||
Connections.Add(fakeConn)
|
||||
err = Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Connections.GetActiveSessions(username), 2)
|
||||
assert.Len(t, Connections.GetStats(), 3)
|
||||
Connections.RLock()
|
||||
|
@ -521,7 +548,8 @@ func TestCloseConnection(t *testing.T) {
|
|||
BaseConnection: c,
|
||||
}
|
||||
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
|
||||
Connections.Add(fakeConn)
|
||||
err := Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
res := Connections.Close(fakeConn.GetID())
|
||||
assert.True(t, res)
|
||||
|
@ -536,19 +564,34 @@ func TestSwapConnection(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
Connections.Add(fakeConn)
|
||||
err := Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, Connections.GetStats(), 1) {
|
||||
assert.Equal(t, "", Connections.GetStats()[0].Username)
|
||||
}
|
||||
c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: userTestUsername,
|
||||
MaxSessions: 1,
|
||||
},
|
||||
})
|
||||
fakeConn = &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
err := Connections.Swap(fakeConn)
|
||||
c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: userTestUsername,
|
||||
},
|
||||
})
|
||||
fakeConn1 := &fakeConnection{
|
||||
BaseConnection: c1,
|
||||
}
|
||||
err = Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
err = Connections.Swap(fakeConn)
|
||||
assert.Error(t, err)
|
||||
Connections.Remove(fakeConn1.ID)
|
||||
err = Connections.Swap(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, Connections.GetStats(), 1) {
|
||||
assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
|
||||
|
@ -600,9 +643,12 @@ func TestConnectionStatus(t *testing.T) {
|
|||
command: "PROPFIND",
|
||||
}
|
||||
t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
|
||||
Connections.Add(fakeConn1)
|
||||
Connections.Add(fakeConn2)
|
||||
Connections.Add(fakeConn3)
|
||||
err := Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
err = Connections.Add(fakeConn2)
|
||||
assert.NoError(t, err)
|
||||
err = Connections.Add(fakeConn3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats := Connections.GetStats()
|
||||
assert.Len(t, stats, 3)
|
||||
|
@ -628,7 +674,7 @@ func TestConnectionStatus(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
err := t1.Close()
|
||||
err = t1.Close()
|
||||
assert.NoError(t, err)
|
||||
err = t2.Close()
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -80,6 +80,11 @@ func (c *BaseConnection) GetUsername() string {
|
|||
return c.User.Username
|
||||
}
|
||||
|
||||
// GetMaxSessions returns the maximum number of concurrent sessions allowed
|
||||
func (c *BaseConnection) GetMaxSessions() int {
|
||||
return c.User.MaxSessions
|
||||
}
|
||||
|
||||
// GetProtocol returns the protocol for the connection
|
||||
func (c *BaseConnection) GetProtocol() string {
|
||||
return c.protocol
|
||||
|
|
|
@ -62,7 +62,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
|
||||
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
|
||||
transfer1.BytesReceived = 150
|
||||
Connections.Add(fakeConn1)
|
||||
err = Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
// the transferschecker will do nothing if there is only one ongoing transfer
|
||||
Connections.checkTransfers()
|
||||
assert.Nil(t, transfer1.errAbort)
|
||||
|
@ -76,7 +77,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
|
||||
transfer1.BytesReceived = 50
|
||||
transfer2.BytesReceived = 60
|
||||
Connections.Add(fakeConn2)
|
||||
err = Connections.Add(fakeConn2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
connID3 := xid.New().String()
|
||||
conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user)
|
||||
|
@ -86,7 +88,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
|
||||
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
|
||||
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
|
||||
Connections.Add(fakeConn3)
|
||||
err = Connections.Add(fakeConn3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// the transfers are not overquota
|
||||
Connections.checkTransfers()
|
||||
|
@ -146,7 +149,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"),
|
||||
filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0,
|
||||
100, 0, true, fsFolder, dataprovider.TransferQuota{})
|
||||
Connections.Add(fakeConn4)
|
||||
err = Connections.Add(fakeConn4)
|
||||
assert.NoError(t, err)
|
||||
connID5 := xid.New().String()
|
||||
conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user)
|
||||
fakeConn5 := &fakeConnection{
|
||||
|
@ -156,7 +160,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
|
|||
filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0,
|
||||
100, 0, true, fsFolder, dataprovider.TransferQuota{})
|
||||
|
||||
Connections.Add(fakeConn5)
|
||||
err = Connections.Add(fakeConn5)
|
||||
assert.NoError(t, err)
|
||||
transfer4.BytesReceived = 50
|
||||
transfer5.BytesReceived = 40
|
||||
Connections.checkTransfers()
|
||||
|
@ -245,7 +250,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
|
|||
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
|
||||
"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
|
||||
transfer1.BytesReceived = 150
|
||||
Connections.Add(fakeConn1)
|
||||
err = Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
// the transferschecker will do nothing if there is only one ongoing transfer
|
||||
Connections.checkTransfers()
|
||||
assert.Nil(t, transfer1.errAbort)
|
||||
|
@ -258,7 +264,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
|
|||
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
|
||||
"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
|
||||
transfer2.BytesReceived = 150
|
||||
Connections.Add(fakeConn2)
|
||||
err = Connections.Add(fakeConn2)
|
||||
assert.NoError(t, err)
|
||||
Connections.checkTransfers()
|
||||
assert.Nil(t, transfer1.errAbort)
|
||||
assert.Nil(t, transfer2.errAbort)
|
||||
|
@ -294,7 +301,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
|
|||
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
|
||||
"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
|
||||
transfer3.BytesSent = 150
|
||||
Connections.Add(fakeConn3)
|
||||
err = Connections.Add(fakeConn3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
connID4 := xid.New().String()
|
||||
conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user)
|
||||
|
@ -304,7 +312,8 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
|
|||
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
|
||||
"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
|
||||
transfer4.BytesSent = 150
|
||||
Connections.Add(fakeConn4)
|
||||
err = Connections.Add(fakeConn4)
|
||||
assert.NoError(t, err)
|
||||
Connections.checkTransfers()
|
||||
assert.Nil(t, transfer3.errAbort)
|
||||
assert.Nil(t, transfer4.errAbort)
|
||||
|
|
|
@ -593,7 +593,8 @@ func TestClientVersion(t *testing.T) {
|
|||
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user),
|
||||
clientContext: mockCC,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err := common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
stats := common.Connections.GetStats()
|
||||
if assert.Len(t, stats, 1) {
|
||||
assert.Equal(t, "mock version", stats[0].ClientVersion)
|
||||
|
|
|
@ -168,8 +168,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
|||
cc.RemoteAddr().String(), user),
|
||||
clientContext: cc,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
return s.initialMsg, nil
|
||||
err = common.Connections.Add(connection)
|
||||
return s.initialMsg, err
|
||||
}
|
||||
|
||||
// ClientDisconnected is called when the user disconnects, even if he never authenticated
|
||||
|
@ -367,9 +367,9 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
|
|||
}
|
||||
err = common.Connections.Swap(connection)
|
||||
if err != nil {
|
||||
err = user.CloseFs()
|
||||
logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err)
|
||||
return nil, common.ErrInternalFailure
|
||||
errClose := user.CloseFs()
|
||||
logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose)
|
||||
return nil, err
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
|
|
15
go.mod
15
go.mod
|
@ -12,9 +12,9 @@ require (
|
|||
github.com/aws/aws-sdk-go-v2/config v1.15.3
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.11.2
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5
|
||||
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5
|
||||
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.16.3
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.8
|
||||
|
@ -32,10 +32,10 @@ require (
|
|||
github.com/grandcat/zeroconf v1.0.0
|
||||
github.com/hashicorp/go-hclog v1.2.0
|
||||
github.com/hashicorp/go-plugin v1.4.3
|
||||
github.com/hashicorp/go-retryablehttp v0.7.0
|
||||
github.com/hashicorp/go-retryablehttp v0.7.1
|
||||
github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
|
||||
github.com/klauspost/compress v1.15.1
|
||||
github.com/lestrrat-go/jwx v1.2.22
|
||||
github.com/lestrrat-go/jwx v1.2.23
|
||||
github.com/lib/pq v1.10.5
|
||||
github.com/lithammer/shortuuid/v3 v3.0.7
|
||||
github.com/mattn/go-sqlite3 v1.14.12
|
||||
|
@ -54,7 +54,7 @@ require (
|
|||
github.com/shirou/gopsutil/v3 v3.22.3
|
||||
github.com/spf13/afero v1.8.2
|
||||
github.com/spf13/cobra v1.4.0
|
||||
github.com/spf13/viper v1.10.1
|
||||
github.com/spf13/viper v1.11.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/studio-b12/gowebdav v0.0.0-20220128162035-c7b1ff8a5e62
|
||||
github.com/unrolled/secure v1.10.0
|
||||
|
@ -67,7 +67,7 @@ require (
|
|||
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
|
||||
golang.org/x/net v0.0.0-20220412020605-290c469a71a5
|
||||
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
|
||||
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad
|
||||
golang.org/x/time v0.0.0-20220411224347-583f2d630306
|
||||
google.golang.org/api v0.74.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
|
@ -130,6 +130,7 @@ require (
|
|||
github.com/mitchellh/mapstructure v1.4.3 // indirect
|
||||
github.com/oklog/run v1.1.0 // indirect
|
||||
github.com/pelletier/go-toml v1.9.4 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20220216144756-c35f1ee13d7c // indirect
|
||||
|
@ -151,7 +152,7 @@ require (
|
|||
golang.org/x/tools v0.1.10 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac // indirect
|
||||
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 // indirect
|
||||
google.golang.org/grpc v1.45.0 // indirect
|
||||
google.golang.org/protobuf v1.28.0 // indirect
|
||||
gopkg.in/ini.v1 v1.66.4 // indirect
|
||||
|
|
29
go.sum
29
go.sum
|
@ -142,8 +142,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.11.2/go.mod h1:j8YsY9TXTm31k4eFhspiQ
|
|||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3 h1:LWPg5zjHV9oz/myQr4wMs0gi4CjnDN/ILmyZUFYXZsU=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk=
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44=
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4 h1:iqcMQBj/B3FPxVb5SGNHC8XAh64hmaWUC8piZArBE7U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.4/go.mod h1:s79ZPBpDzcR1BCuAhGCF1rgmd/QmLueKCvdkmX4SDgg=
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5 h1:lPo/NX1o4vkk2C7mHmB2FCf9Qp7KZNHrlzHxdP/yugw=
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.5/go.mod h1:JNo9mEKrjnmDBc19z65TZmj1xG9PQHu2GOlApYk31DU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9 h1:onz/VaaxZ7Z4V+WIN9Txly9XLTmoOh1oJ8XcAC3pako=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3 h1:9stUQR/u2KXU6HkFJYlqnZEjBnbgrVbG6I5HN09xZh0=
|
||||
|
@ -164,8 +164,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYY
|
|||
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3 h1:xqXHk4UDW7ii4MRciyLpY87yuZds0iymmgHt3h35xTE=
|
||||
github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.3/go.mod h1:HT0cm2+NUCF33MdXjck554HC6VRgQ4q6JIlSqlYZ18Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4 h1:frOI/v6KWuKGlKUA5gheRw01EDpxcCxTalFQkCOZXAo=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.4/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5 h1:A3PuAUlh1u47WHcM68CDaG9ZWjK7ewePjDp+0dY9yv4=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.5/go.mod h1:qFKU5d+PAv+23bi9ZhtWeA+TmLUz7B/R59ZGXQ1Mmu4=
|
||||
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4 h1:EmIEXOjAdXtxa2OGM1VAajZV/i06Q8qd4kBpJd9/p1k=
|
||||
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sns v1.17.4/go.mod h1:kElt+uCcXxcqFyc+bQqZPFD9DME/eC6oHBXvFzQ9Bcw=
|
||||
|
@ -427,8 +427,8 @@ github.com/hashicorp/go-hclog v1.2.0 h1:La19f8d7WIlm4ogzNHB0JGqs5AUDAZ2UfCY4sJXc
|
|||
github.com/hashicorp/go-hclog v1.2.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
|
||||
github.com/hashicorp/go-plugin v1.4.3 h1:DXmvivbWD5qdiBts9TpBC7BYL1Aia5sxbRgQB+v6UZM=
|
||||
github.com/hashicorp/go-plugin v1.4.3/go.mod h1:5fGEH17QVwTTcR0zV7yhDPLLmFX9YSZ38b18Udy6vYQ=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.0 h1:eu1EI/mbirUgP5C8hVsTNaGZreBDlYiwC1FZWkvQPQ4=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.0/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ=
|
||||
github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
|
@ -549,8 +549,8 @@ github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbq
|
|||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU=
|
||||
github.com/lestrrat-go/jwx v1.2.22 h1:bCxMokwHNuJHVxgANP4OBddXGtQ9Oy+6cqp4O2rW7DU=
|
||||
github.com/lestrrat-go/jwx v1.2.22/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
|
||||
github.com/lestrrat-go/jwx v1.2.23 h1:8oP5fY1yzCRraUNNyfAVdOkLCqY7xMZz11lVcvHqC1Y=
|
||||
github.com/lestrrat-go/jwx v1.2.23/go.mod h1:sAXjRwzSvCN6soO4RLoWWm1bVPpb8iOuv0IYfH8OWd8=
|
||||
github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4=
|
||||
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
|
@ -627,6 +627,8 @@ github.com/otiai10/mint v1.3.3 h1:7JgpsBaN0uMkyju4tbYHu0mnM55hNKVYLsXmwr15NQI=
|
|||
github.com/otiai10/mint v1.3.3/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc=
|
||||
github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM=
|
||||
github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
|
||||
github.com/pelletier/go-toml/v2 v2.0.0-beta.8 h1:dy81yyLYJDwMTifq24Oi/IslOslRrDSb3jwDggjz3Z0=
|
||||
github.com/pelletier/go-toml/v2 v2.0.0-beta.8/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
|
||||
github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
|
||||
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
|
||||
|
@ -708,8 +710,8 @@ github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmq
|
|||
github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk=
|
||||
github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU=
|
||||
github.com/spf13/viper v1.11.0 h1:7OX/1FS6n7jHD1zGrZTM7WtY13ZELRyosK4k93oPr44=
|
||||
github.com/spf13/viper v1.11.0/go.mod h1:djo0X/bA5+tYVoCn+C7cAYJGcVn/qYLFTG8gdUsX7Zk=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
|
@ -934,8 +936,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5 h1:NubxfvTRuNb4RVzWrIDAUzUvREH1HkCD4JjyQTSG9As=
|
||||
golang.org/x/sys v0.0.0-20220412071739-889880a91fd5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0=
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
|
@ -1166,8 +1168,9 @@ google.golang.org/genproto v0.0.0-20220310185008-1973136f34c6/go.mod h1:kGP+zUP2
|
|||
google.golang.org/genproto v0.0.0-20220324131243-acbaeb5b85eb/go.mod h1:hAL49I2IFola2sVEjAn7MEwsja0xp51I0tlGAf9hz4E=
|
||||
google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
|
||||
google.golang.org/genproto v0.0.0-20220405205423-9d709892a2bf/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
|
||||
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac h1:qSNTkEN+L2mvWcLgJOR+8bdHX9rN/IdU3A1Ghpfb1Rg=
|
||||
google.golang.org/genproto v0.0.0-20220407144326-9054f6ed7bac/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
|
||||
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9 h1:XGQ6tc+EnM35IAazg4y6AHmUg4oK8NXsXaILte1vRlk=
|
||||
google.golang.org/genproto v0.0.0-20220413183235-5e96e2839df9/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
|
||||
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
|
|
|
@ -34,7 +34,7 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
|
|||
connID := xid.New().String()
|
||||
protocol := getProtocolFromRequest(r)
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
|
||||
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -43,6 +43,10 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
|
|||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return connection, err
|
||||
}
|
||||
return connection, nil
|
||||
}
|
||||
|
||||
|
@ -52,7 +56,6 @@ func readUserFolder(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -70,7 +73,6 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -90,22 +92,7 @@ func createUserDir(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
func renameUserDir(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
connection, err := getUserConnection(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
|
||||
err = connection.Rename(oldName, newName)
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename directory %#v to %#v", oldName, newName),
|
||||
getMappedStatusCode(err))
|
||||
return
|
||||
}
|
||||
sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %#v renamed to %#v", oldName, newName), http.StatusOK)
|
||||
renameItem(w, r)
|
||||
}
|
||||
|
||||
func deleteUserDir(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -114,7 +101,6 @@ func deleteUserDir(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -132,7 +118,6 @@ func getUserFile(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -183,7 +168,6 @@ func setFileDirMetadata(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -214,7 +198,6 @@ func uploadUserFile(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -258,6 +241,8 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
transferQuota := connection.GetTransferQuota()
|
||||
if !transferQuota.HasUploadSpace() {
|
||||
connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")
|
||||
|
@ -265,8 +250,6 @@ func uploadUserFiles(w http.ResponseWriter, r *http.Request) {
|
|||
http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
|
||||
r.Body = t
|
||||
|
@ -332,22 +315,7 @@ func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connectio
|
|||
|
||||
func renameUserFile(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
connection, err := getUserConnection(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
|
||||
err = connection.Rename(oldName, newName)
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename file %#v to %#v", oldName, newName),
|
||||
getMappedStatusCode(err))
|
||||
return
|
||||
}
|
||||
sendAPIResponse(w, r, nil, fmt.Sprintf("File %#v renamed to %#v", oldName, newName), http.StatusOK)
|
||||
renameItem(w, r)
|
||||
}
|
||||
|
||||
func deleteUserFile(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -356,7 +324,6 @@ func deleteUserFile(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -393,7 +360,6 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
var filesList []string
|
||||
|
@ -581,3 +547,21 @@ func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath stri
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func renameItem(w http.ResponseWriter, r *http.Request) {
|
||||
connection, err := getUserConnection(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
newName := connection.User.GetCleanedPath(r.URL.Query().Get("target"))
|
||||
err = connection.Rename(oldName, newName)
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %#v -> %#v", oldName, newName),
|
||||
getMappedStatusCode(err))
|
||||
return
|
||||
}
|
||||
sendAPIResponse(w, r, nil, fmt.Sprintf("%#v renamed to %#v", oldName, newName), http.StatusOK)
|
||||
}
|
||||
|
|
|
@ -167,7 +167,10 @@ func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.
|
|||
return
|
||||
}
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
contents, err := connection.ReadDir(name)
|
||||
|
@ -194,7 +197,10 @@ func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http
|
|||
return
|
||||
}
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
info, err := connection.Stat(name, 1)
|
||||
|
@ -231,7 +237,10 @@ func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request)
|
|||
return
|
||||
}
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
compress := true
|
||||
|
@ -289,7 +298,10 @@ func (s *httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
if err := doUploadFile(w, r, connection, filePath); err != nil {
|
||||
dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
|
||||
|
@ -313,7 +325,10 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request)
|
|||
return
|
||||
}
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
|
||||
|
|
|
@ -498,7 +498,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
|
|||
dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
|
||||
}
|
||||
|
||||
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string) error {
|
||||
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
|
||||
if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) {
|
||||
logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username)
|
||||
return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username)
|
||||
|
@ -507,7 +507,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID
|
|||
logger.Info(logSender, connectionID, "cannot login user %#v, password login method is not allowed", user.Username)
|
||||
return fmt.Errorf("login method password is not allowed for user %#v", user.Username)
|
||||
}
|
||||
if user.MaxSessions > 0 {
|
||||
if checkSessions && user.MaxSessions > 0 {
|
||||
activeSessions := common.Connections.GetActiveSessions(user.Username)
|
||||
if activeSessions >= user.MaxSessions {
|
||||
logger.Info(logSender, connectionID, "authentication refused for user: %#v, too many open sessions: %v/%v", user.Username,
|
||||
|
|
|
@ -3857,7 +3857,8 @@ func TestCloseActiveConnection(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
common.Connections.Add(fakeConn)
|
||||
err = common.Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.CloseConnection(c.GetID(), http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
|
@ -3870,12 +3871,14 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
common.Connections.Add(fakeConn)
|
||||
err = common.Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user)
|
||||
fakeConn1 := &fakeConnection{
|
||||
BaseConnection: c1,
|
||||
}
|
||||
common.Connections.Add(fakeConn1)
|
||||
err = common.Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 2)
|
||||
|
@ -3883,8 +3886,10 @@ func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
|
||||
common.Connections.Add(fakeConn)
|
||||
common.Connections.Add(fakeConn1)
|
||||
err = common.Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
err = common.Connections.Add(fakeConn1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 2)
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
@ -5173,7 +5178,8 @@ func TestLoaddataMode(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
common.Connections.Add(fakeConn)
|
||||
err = common.Connections.Add(fakeConn)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 1)
|
||||
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
@ -8714,7 +8720,8 @@ func TestWebClientMaxConnections(t *testing.T) {
|
|||
connection := &httpd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err = common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.Error(t, err)
|
||||
|
@ -8895,20 +8902,57 @@ func TestMaxSessions(t *testing.T) {
|
|||
u.Email = "user@session.com"
|
||||
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.NoError(t, err)
|
||||
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.NoError(t, err)
|
||||
// now add a fake connection
|
||||
fs := vfs.NewOsFs("id", os.TempDir(), "")
|
||||
connection := &httpd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err = common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.Error(t, err)
|
||||
_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.Error(t, err)
|
||||
// try an user API call
|
||||
req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, apiToken)
|
||||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
// web client requests
|
||||
req, err = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil)
|
||||
assert.NoError(t, err)
|
||||
setJWTCookieForReq(req, webToken)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil)
|
||||
assert.NoError(t, err)
|
||||
setJWTCookieForReq(req, webToken)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil)
|
||||
assert.NoError(t, err)
|
||||
setJWTCookieForReq(req, webToken)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil)
|
||||
assert.NoError(t, err)
|
||||
setJWTCookieForReq(req, webToken)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
// test reset password
|
||||
smtpCfg := smtp.Config{
|
||||
Host: "127.0.0.1",
|
||||
|
@ -8924,11 +8968,11 @@ func TestMaxSessions(t *testing.T) {
|
|||
form.Set(csrfFormToken, csrfToken)
|
||||
form.Set("username", user.Username)
|
||||
lastResetCode = ""
|
||||
req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
|
||||
req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
|
||||
assert.NoError(t, err)
|
||||
req.RemoteAddr = defaultRemoteAddr
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := executeRequest(req)
|
||||
rr = executeRequest(req)
|
||||
assert.Equal(t, http.StatusFound, rr.Code)
|
||||
assert.GreaterOrEqual(t, len(lastResetCode), 20)
|
||||
form = make(url.Values)
|
||||
|
@ -9630,6 +9674,123 @@ func TestShareUsage(t *testing.T) {
|
|||
executeRequest(req)
|
||||
}
|
||||
|
||||
func TestShareMaxSessions(t *testing.T) {
|
||||
u := getTestUser()
|
||||
u.MaxSessions = 1
|
||||
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
|
||||
assert.NoError(t, err)
|
||||
|
||||
share := dataprovider.Share{
|
||||
Name: "test share max sessions read",
|
||||
Scope: dataprovider.ShareScopeRead,
|
||||
Paths: []string{"/"},
|
||||
}
|
||||
asJSON, err := json.Marshal(share)
|
||||
assert.NoError(t, err)
|
||||
req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusCreated, rr)
|
||||
objectID := rr.Header().Get("X-Object-ID")
|
||||
assert.NotEmpty(t, objectID)
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusOK, rr)
|
||||
// add a fake connection
|
||||
fs := vfs.NewOsFs("id", os.TempDir(), "")
|
||||
connection := &httpd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
|
||||
}
|
||||
err = common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil)
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusOK, rr)
|
||||
|
||||
// now test a write share
|
||||
share = dataprovider.Share{
|
||||
Name: "test share max sessions write",
|
||||
Scope: dataprovider.ShareScopeWrite,
|
||||
Paths: []string{"/"},
|
||||
}
|
||||
asJSON, err = json.Marshal(share)
|
||||
assert.NoError(t, err)
|
||||
req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusCreated, rr)
|
||||
objectID = rr.Header().Get("X-Object-ID")
|
||||
assert.NotEmpty(t, objectID)
|
||||
|
||||
req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content")))
|
||||
assert.NoError(t, err)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
body := new(bytes.Buffer)
|
||||
writer := multipart.NewWriter(body)
|
||||
part1, err := writer.CreateFormFile("filenames", "file1.txt")
|
||||
assert.NoError(t, err)
|
||||
_, err = part1.Write([]byte("file1 content"))
|
||||
assert.NoError(t, err)
|
||||
err = writer.Close()
|
||||
assert.NoError(t, err)
|
||||
reader := bytes.NewReader(body.Bytes())
|
||||
req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
|
||||
assert.NoError(t, err)
|
||||
req.Header.Add("Content-Type", writer.FormDataContentType())
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusTooManyRequests, rr)
|
||||
assert.Contains(t, rr.Body.String(), "too many open sessions")
|
||||
|
||||
common.Connections.Remove(connection.GetID())
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
}
|
||||
|
||||
func TestShareUploadSingle(t *testing.T) {
|
||||
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -434,7 +434,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
|
|||
return err
|
||||
}
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
|
||||
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -331,7 +331,7 @@ func (t *oidcToken) getUser(r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
connectionID := fmt.Sprintf("%v_%v", common.ProtocolOIDC, xid.New().String())
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
|
||||
updateLoginMetrics(&user, dataprovider.LoginMethodIDP, ipAddr, err)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -228,7 +228,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
|
|||
return
|
||||
}
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
|
||||
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
|
||||
s.renderClientLoginPage(w, err.Error(), ipAddr)
|
||||
return
|
||||
|
@ -268,7 +268,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
|
|||
return
|
||||
}
|
||||
connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
|
||||
if err := checkHTTPClientUser(user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
|
||||
s.renderClientResetPwdPage(w, fmt.Sprintf("Password reset successfully but unable to login: %v", err.Error()), ipAddr)
|
||||
return
|
||||
}
|
||||
|
@ -760,7 +760,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
|
||||
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
|
||||
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
|
@ -920,7 +920,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
|
|||
logger.Debug(logSender, "", "signature mismatch for user %#v, unable to refresh cookie", user.Username)
|
||||
return
|
||||
}
|
||||
if err := checkHTTPClientUser(&user, r, xid.New().String()); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
|
||||
logger.Debug(logSender, "", "unable to refresh cookie for user %#v: %v", user.Username, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -595,7 +595,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
|
|||
connID := xid.New().String()
|
||||
protocol := getProtocolFromRequest(r)
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
|
||||
s.renderClientForbiddenPage(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
@ -604,7 +604,10 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
|
|||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -635,7 +638,10 @@ func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.R
|
|||
s.renderClientMessagePage(w, r, "Invalid share path", "", getRespStatus(err), err, "")
|
||||
return
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
contents, err := connection.ReadDir(name)
|
||||
|
@ -691,7 +697,10 @@ func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
var info os.FileInfo
|
||||
|
@ -735,7 +744,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
|
|||
connID := xid.New().String()
|
||||
protocol := getProtocolFromRequest(r)
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
|
||||
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
@ -744,7 +753,10 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
|
|||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -809,7 +821,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
|
|||
connID := xid.New().String()
|
||||
protocol := getProtocolFromRequest(r)
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
|
||||
s.renderClientForbiddenPage(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
@ -818,7 +830,10 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
|
|||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
@ -866,7 +881,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
|
|||
connID := xid.New().String()
|
||||
protocol := getProtocolFromRequest(r)
|
||||
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
|
||||
if err := checkHTTPClientUser(&user, r, connectionID); err != nil {
|
||||
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
|
||||
s.renderClientForbiddenPage(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
@ -875,7 +890,10 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
|
|||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
s.renderClientMessagePage(w, r, "Unable to add connection", "", http.StatusTooManyRequests, err, "")
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
|
||||
|
|
|
@ -369,7 +369,7 @@ func TestTruncate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPBasicHandlingCryptoFs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -427,7 +427,7 @@ func TestSCPBasicHandlingCryptoFs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPRecursiveCryptFs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
|
|
@ -457,8 +457,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileO
|
|||
return t, nil
|
||||
}
|
||||
|
||||
// Disconnect disconnects the client closing the network connection
|
||||
// Disconnect disconnects the client by closing the channel
|
||||
func (c *Connection) Disconnect() error {
|
||||
if c.channel == nil {
|
||||
c.Log(logger.LevelWarn, "cannot disconnect a nil channel")
|
||||
return nil
|
||||
}
|
||||
return c.channel.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/eikenb/pipeat"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/rs/xid"
|
||||
"github.com/sftpgo/sdk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -2152,3 +2153,45 @@ func TestLoadRevokedUserCertsFile(t *testing.T) {
|
|||
err = os.RemoveAll(r.filePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMaxUserSessions(t *testing.T) {
|
||||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
Username: "user_max_sessions",
|
||||
HomeDir: filepath.Clean(os.TempDir()),
|
||||
MaxSessions: 1,
|
||||
},
|
||||
}),
|
||||
}
|
||||
err := common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
c := Configuration{}
|
||||
c.handleSftpConnection(nil, connection)
|
||||
|
||||
sshCmd := sshCommand{
|
||||
command: "cd",
|
||||
connection: connection,
|
||||
}
|
||||
err = sshCmd.handle()
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "too many open sessions")
|
||||
}
|
||||
scpCmd := scpCommand{
|
||||
sshCommand: sshCommand{
|
||||
command: "scp",
|
||||
connection: connection,
|
||||
},
|
||||
}
|
||||
err = scpCmd.handle()
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "too many open sessions")
|
||||
}
|
||||
err = ServeSubSystemConnection(&connection.User, connection.ID, nil, nil)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "too many open sessions")
|
||||
}
|
||||
common.Connections.Remove(connection.GetID())
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
}
|
||||
|
|
|
@ -37,7 +37,10 @@ func (c *scpCommand) handle() (err error) {
|
|||
err = common.ErrGenericFailure
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(c.connection)
|
||||
if err := common.Connections.Add(c.connection); err != nil {
|
||||
logger.Info(logSender, "", "unable to add SCP connection: %v", err)
|
||||
return err
|
||||
}
|
||||
defer common.Connections.Remove(c.connection.GetID())
|
||||
|
||||
destPath := c.getDestPath()
|
||||
|
|
|
@ -562,7 +562,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
case "subsystem":
|
||||
if string(req.Payload[4:]) == "sftp" {
|
||||
ok = true
|
||||
connection := Connection{
|
||||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(),
|
||||
conn.RemoteAddr().String(), user),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
|
@ -571,7 +571,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
channel: channel,
|
||||
folderPrefix: c.FolderPrefix,
|
||||
}
|
||||
go c.handleSftpConnection(channel, &connection)
|
||||
go c.handleSftpConnection(channel, connection)
|
||||
}
|
||||
case "exec":
|
||||
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
||||
|
@ -600,7 +600,11 @@ func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Co
|
|||
logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(connection)
|
||||
if err := common.Connections.Add(connection); err != nil {
|
||||
errClose := connection.Disconnect()
|
||||
logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
// Create the server instance for the channel using the handler we created above.
|
||||
|
|
|
@ -3,6 +3,7 @@ package sftpd_test
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
|
@ -129,6 +130,7 @@ var (
|
|||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
scpPath string
|
||||
scpForce bool
|
||||
gitPath string
|
||||
sshPath string
|
||||
hookCmdPath string
|
||||
|
@ -935,8 +937,6 @@ func TestConcurrency(t *testing.T) {
|
|||
|
||||
conn, client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
err = checkBasicSFTP(client)
|
||||
assert.NoError(t, err)
|
||||
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
|
||||
|
@ -9231,7 +9231,7 @@ func TestGitErrors(t *testing.T) {
|
|||
|
||||
// Start SCP tests
|
||||
func TestSCPBasicHandling(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9295,7 +9295,7 @@ func TestSCPBasicHandling(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPUploadFileOverwrite(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9376,7 +9376,7 @@ func TestSCPUploadFileOverwrite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPRecursive(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9485,7 +9485,7 @@ func TestSCPStartDirectory(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPatternsFilter(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9606,7 +9606,7 @@ func TestSCPUploadMaxSize(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPVirtualFolders(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9658,7 +9658,7 @@ func TestSCPVirtualFolders(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPNestedFolders(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
|
||||
|
@ -9791,7 +9791,7 @@ func TestSCPNestedFolders(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPVirtualFoldersQuota(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9889,7 +9889,7 @@ func TestSCPVirtualFoldersQuota(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPermsSubDirs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9929,7 +9929,7 @@ func TestSCPPermsSubDirs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPermCreateDirs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9963,7 +9963,7 @@ func TestSCPPermCreateDirs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPermUpload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -9987,7 +9987,7 @@ func TestSCPPermUpload(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPermOverwrite(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10014,7 +10014,7 @@ func TestSCPPermOverwrite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPPermDownload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10043,7 +10043,7 @@ func TestSCPPermDownload(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPQuotaSize(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10099,7 +10099,7 @@ func TestSCPQuotaSize(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPEscapeHomeDir(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10135,7 +10135,7 @@ func TestSCPEscapeHomeDir(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPUploadPaths(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10170,7 +10170,7 @@ func TestSCPUploadPaths(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPOverwriteDirWithFile(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
|
@ -10194,7 +10194,7 @@ func TestSCPOverwriteDirWithFile(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPRemoteToRemote(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
if runtime.GOOS == osWindows {
|
||||
|
@ -10229,7 +10229,7 @@ func TestSCPRemoteToRemote(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSCPErrors(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
if scpPath == "" {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
u := getTestUser(true)
|
||||
|
@ -10682,6 +10682,9 @@ func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive
|
|||
if recursive {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
if scpForce {
|
||||
args = append(args, "-O")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
|
@ -10707,6 +10710,9 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem
|
|||
args = append(args, "-r")
|
||||
}
|
||||
}
|
||||
if scpForce {
|
||||
args = append(args, "-O")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
|
@ -10770,6 +10776,12 @@ func checkSystemCommands() {
|
|||
logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
scpPath = ""
|
||||
} else {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(ctx, scpPath, "-O")
|
||||
out, _ := cmd.CombinedOutput()
|
||||
scpForce = !strings.Contains(string(out), "option -- O")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -115,7 +115,10 @@ func (c *sshCommand) handle() (err error) {
|
|||
err = common.ErrGenericFailure
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(c.connection)
|
||||
if err := common.Connections.Add(c.connection); err != nil {
|
||||
logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
|
||||
return err
|
||||
}
|
||||
defer common.Connections.Remove(c.connection.GetID())
|
||||
|
||||
c.connection.UpdateLastActivity()
|
||||
|
@ -131,7 +134,7 @@ func (c *sshCommand) handle() (err error) {
|
|||
c.sendExitStatus(nil)
|
||||
} else if c.command == "pwd" {
|
||||
// hard coded response to "/"
|
||||
c.connection.channel.Write([]byte("/\n")) //nolint:errcheck
|
||||
c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
|
||||
c.sendExitStatus(nil)
|
||||
} else if c.command == "sftpgo-copy" {
|
||||
return c.handleSFTPGoCopy()
|
||||
|
|
|
@ -43,7 +43,6 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
|
|||
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
|
||||
return err
|
||||
}
|
||||
dataprovider.UpdateLastLogin(user)
|
||||
|
||||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, "", "", *user),
|
||||
|
@ -52,9 +51,15 @@ func ServeSubSystemConnection(user *dataprovider.User, connectionID string, read
|
|||
LocalAddr: &net.IPAddr{},
|
||||
channel: newSubsystemChannel(reader, writer),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err = common.Connections.Add(connection)
|
||||
if err != nil {
|
||||
errClose := user.CloseFs()
|
||||
logger.Warn(logSender, connectionID, "unable to add connection: %v close fs error: %v", err, errClose)
|
||||
return err
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
dataprovider.UpdateLastLogin(user)
|
||||
server := sftp.NewRequestServer(connection.channel, sftp.Handlers{
|
||||
FileGet: connection,
|
||||
FilePut: connection,
|
||||
|
|
|
@ -203,18 +203,14 @@ func (fs *CryptFs) ReadDir(dirname string) ([]os.FileInfo, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := f.ReadDir(-1)
|
||||
list, err := f.Readdir(-1)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]os.FileInfo, len(entries))
|
||||
for idx, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[idx] = fs.ConvertFileInfo(info)
|
||||
result := make([]os.FileInfo, 0, len(list))
|
||||
for _, info := range list {
|
||||
result = append(result, fs.ConvertFileInfo(info))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
13
vfs/osfs.go
13
vfs/osfs.go
|
@ -3,7 +3,6 @@ package vfs
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
@ -171,20 +170,12 @@ func (*OsFs) ReadDir(dirname string) ([]os.FileInfo, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := f.ReadDir(-1)
|
||||
list, err := f.Readdir(-1)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]fs.FileInfo, len(entries))
|
||||
for idx, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[idx] = info
|
||||
}
|
||||
return result, nil
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// IsUploadResumeSupported returns true if resuming uploads is supported
|
||||
|
|
|
@ -196,19 +196,25 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
updateLoginMetrics(&user, ipAddr, loginMethod, err)
|
||||
|
||||
ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
|
||||
ctx = context.WithValue(ctx, requestStartKey, time.Now())
|
||||
|
||||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r),
|
||||
r.RemoteAddr, user),
|
||||
request: r,
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
if err = common.Connections.Add(connection); err != nil {
|
||||
errClose := user.CloseFs()
|
||||
logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose)
|
||||
updateLoginMetrics(&user, ipAddr, loginMethod, err)
|
||||
http.Error(w, err.Error(), http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
updateLoginMetrics(&user, ipAddr, loginMethod, err)
|
||||
|
||||
ctx := context.WithValue(r.Context(), requestIDKey, connectionID)
|
||||
ctx = context.WithValue(ctx, requestStartKey, time.Now())
|
||||
|
||||
dataprovider.UpdateLastLogin(&user)
|
||||
|
||||
if s.checkRequestMethod(ctx, r, connection) {
|
||||
|
@ -311,14 +317,6 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
|
|||
user.Username, loginMethod)
|
||||
return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
|
||||
}
|
||||
if user.MaxSessions > 0 {
|
||||
activeSessions := common.Connections.GetActiveSessions(user.Username)
|
||||
if activeSessions >= user.MaxSessions {
|
||||
logger.Info(logSender, connID, "authentication refused for user: %#v, too many open sessions: %v/%v",
|
||||
user.Username, activeSessions, user.MaxSessions)
|
||||
return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
|
||||
}
|
||||
}
|
||||
if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
|
||||
logger.Info(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v",
|
||||
user.Username, r.RemoteAddr)
|
||||
|
|
|
@ -1167,7 +1167,8 @@ func TestMaxConnections(t *testing.T) {
|
|||
connection := &webdavd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err = common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
common.Connections.Remove(connection.GetID())
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
|
@ -1222,7 +1223,8 @@ func TestMaxSessions(t *testing.T) {
|
|||
connection := &webdavd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
err = common.Connections.Add(connection)
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
common.Connections.Remove(connection.GetID())
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
|
|
Loading…
Reference in a new issue