mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
add support for limiting max concurrent client connections
This commit is contained in:
parent
ea0bf5e4c8
commit
f34462e3c3
11 changed files with 149 additions and 19 deletions
|
@ -248,6 +248,8 @@ type Configuration struct {
|
|||
// and before he tries to login. It allows you to reject the connection based on the source
|
||||
// ip address. Leave empty do disable.
|
||||
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
|
||||
// Maximum number of concurrent client connections. 0 means unlimited
|
||||
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
|
||||
idleTimeoutAsDuration time.Duration
|
||||
idleLoginTimeout time.Duration
|
||||
}
|
||||
|
@ -544,6 +546,18 @@ func (conns *ActiveConnections) checkIdles() {
|
|||
conns.RUnlock()
|
||||
}
|
||||
|
||||
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
|
||||
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
||||
if Config.MaxTotalConnections == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
||||
return len(conns.connections) < Config.MaxTotalConnections
|
||||
}
|
||||
|
||||
// GetStats returns stats for active connections
|
||||
func (conns *ActiveConnections) GetStats() []ConnectionStatus {
|
||||
conns.RLock()
|
||||
|
|
|
@ -225,6 +225,26 @@ func TestSSHConnections(t *testing.T) {
|
|||
assert.NoError(t, sshConn3.Close())
|
||||
}
|
||||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := Config.MaxTotalConnections
|
||||
Config.MaxTotalConnections = 1
|
||||
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
|
||||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
Connections.Add(fakeConn)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
assert.False(t, Connections.IsNewConnectionAllowed())
|
||||
|
||||
res := Connections.Close(fakeConn.GetID())
|
||||
assert.True(t, res)
|
||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
|
||||
|
||||
Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestIdleConnections(t *testing.T) {
|
||||
configCopy := Config
|
||||
|
||||
|
@ -310,6 +330,7 @@ func TestCloseConnection(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
Connections.Add(fakeConn)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
res := Connections.Close(fakeConn.GetID())
|
||||
|
|
|
@ -68,6 +68,8 @@ func Init() {
|
|||
SetstatMode: 0,
|
||||
ProxyProtocol: 0,
|
||||
ProxyAllowed: []string{},
|
||||
PostConnectHook: "",
|
||||
MaxTotalConnections: 0,
|
||||
},
|
||||
SFTPD: sftpd.Configuration{
|
||||
Banner: defaultSFTPDBanner,
|
||||
|
@ -413,6 +415,7 @@ func setViperDefaults() {
|
|||
viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol)
|
||||
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
|
||||
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
|
||||
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
|
||||
viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort)
|
||||
viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress)
|
||||
viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries)
|
||||
|
|
|
@ -63,6 +63,7 @@ The configuration file contains the following sections:
|
|||
- If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored
|
||||
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
|
||||
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
|
||||
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
|
||||
- **"sftpd"**, the configuration for the SFTP server
|
||||
- `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022
|
||||
- `bind_address`, string. Leave blank to listen on all available network interfaces. Default: ""
|
||||
|
|
|
@ -502,6 +502,29 @@ func TestPostConnectHook(t *testing.T) {
|
|||
common.Config.PostConnectHook = ""
|
||||
}
|
||||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxTotalConnections
|
||||
common.Config.MaxTotalConnections = 1
|
||||
|
||||
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
client, err := getFTPClient(user, true)
|
||||
if assert.NoError(t, err) {
|
||||
err = checkBasicFTP(client)
|
||||
assert.NoError(t, err)
|
||||
_, err = getFTPClient(user, false)
|
||||
assert.Error(t, err)
|
||||
err = client.Quit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
_, err = httpd.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
u := getTestUser()
|
||||
u.MaxSessions = 1
|
||||
|
|
|
@ -98,8 +98,12 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
|
|||
|
||||
// ClientConnected is called to send the very first welcome message
|
||||
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
|
||||
return "", common.ErrConnectionDenied
|
||||
}
|
||||
if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil {
|
||||
return common.ErrConnectionDenied.Error(), err
|
||||
return "", err
|
||||
}
|
||||
connID := fmt.Sprintf("%v", cc.ID())
|
||||
user := dataprovider.User{}
|
||||
|
|
|
@ -277,23 +277,22 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
// Before beginning a handshake must be performed on the incoming net.Conn
|
||||
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
|
||||
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
if err := common.Config.ExecutePostConnectHook(remoteAddr.String(), common.ProtocolSSH); err != nil {
|
||||
if err := common.Config.ExecutePostConnectHook(conn.RemoteAddr().String(), common.ProtocolSSH); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
||||
if err != nil {
|
||||
logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
|
||||
if _, ok := err.(*ssh.ServerAuthError); !ok {
|
||||
ip := utils.GetIPFromRemoteAddress(remoteAddr.String())
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
|
||||
metrics.AddNoAuthTryed()
|
||||
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
|
||||
}
|
||||
checkAuthError(conn, err)
|
||||
return
|
||||
}
|
||||
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
|
||||
|
@ -315,7 +314,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
|
||||
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
||||
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
||||
user.ID, loginType, user.Username, user.HomeDir, conn.RemoteAddr().String())
|
||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
||||
|
||||
sshConnection := common.NewSSHConnection(connectionID, conn)
|
||||
|
@ -354,13 +353,13 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
switch req.Type {
|
||||
case "subsystem":
|
||||
if string(req.Payload[4:]) == "sftp" {
|
||||
fs, err := user.GetFilesystem(connectionID)
|
||||
fs, err := user.GetFilesystem(connID)
|
||||
if err == nil {
|
||||
ok = true
|
||||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: remoteAddr,
|
||||
RemoteAddr: conn.RemoteAddr(),
|
||||
channel: channel,
|
||||
}
|
||||
go c.handleSftpConnection(channel, &connection)
|
||||
|
@ -368,12 +367,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
}
|
||||
case "exec":
|
||||
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
||||
fs, err := user.GetFilesystem(connectionID)
|
||||
fs, err := user.GetFilesystem(connID)
|
||||
if err == nil {
|
||||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: remoteAddr,
|
||||
RemoteAddr: conn.RemoteAddr(),
|
||||
channel: channel,
|
||||
}
|
||||
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
|
||||
|
@ -420,6 +419,15 @@ func (c *Configuration) createHandler(connection *Connection) sftp.Handlers {
|
|||
}
|
||||
}
|
||||
|
||||
func checkAuthError(conn net.Conn, err error) {
|
||||
if _, ok := err.(*ssh.ServerAuthError); !ok {
|
||||
ip := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
|
||||
metrics.AddNoAuthTryed()
|
||||
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkRootPath(user *dataprovider.User, connectionID string) error {
|
||||
if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
|
||||
// for sftp fs check root path does nothing so don't open a useless SFTP connection
|
||||
|
|
|
@ -2441,6 +2441,31 @@ func TestQuotaDisabledError(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxTotalConnections
|
||||
common.Config.MaxTotalConnections = 1
|
||||
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := httpd.AddUser(u, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
c, err := getSftpClient(user, usePubKey)
|
||||
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
_, err = httpd.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
usePubKey := false
|
||||
u := getTestUser(usePubKey)
|
||||
|
|
|
@ -9,7 +9,8 @@
|
|||
"setstat_mode": 0,
|
||||
"proxy_protocol": 0,
|
||||
"proxy_allowed": [],
|
||||
"post_connect_hook": ""
|
||||
"post_connect_hook": "",
|
||||
"max_total_connections": 0
|
||||
},
|
||||
"sftpd": {
|
||||
"bind_port": 2022,
|
||||
|
|
|
@ -112,6 +112,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
|
||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
checkRemoteAddress(r)
|
||||
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
|
||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
||||
|
|
|
@ -650,6 +650,31 @@ func TestPostConnectHook(t *testing.T) {
|
|||
common.Config.PostConnectHook = ""
|
||||
}
|
||||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxTotalConnections
|
||||
common.Config.MaxTotalConnections = 1
|
||||
|
||||
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
client := getWebDavClient(user)
|
||||
assert.NoError(t, checkBasicFunc(client))
|
||||
// now add a fake connection
|
||||
fs := vfs.NewOsFs("id", os.TempDir(), nil)
|
||||
connection := &webdavd.Connection{
|
||||
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
|
||||
}
|
||||
common.Connections.Add(connection)
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
common.Connections.Remove(connection.GetID())
|
||||
_, err = httpd.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
|
||||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
u := getTestUser()
|
||||
u.MaxSessions = 1
|
||||
|
|
Loading…
Reference in a new issue