add support for limiting max concurrent client connections

This commit is contained in:
Nicola Murino 2020-12-15 19:29:30 +01:00
parent ea0bf5e4c8
commit f34462e3c3
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
11 changed files with 149 additions and 19 deletions

View file

@ -247,7 +247,9 @@ type Configuration struct {
// Absolute path to an external program or an HTTP URL to invoke after a user connects
// and before he tries to login. It allows you to reject the connection based on the source
// ip address. Leave empty do disable.
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
// Maximum number of concurrent client connections. 0 means unlimited
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
idleTimeoutAsDuration time.Duration
idleLoginTimeout time.Duration
}
@ -544,6 +546,18 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock()
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
if Config.MaxTotalConnections == 0 {
return true
}
conns.RLock()
defer conns.RUnlock()
return len(conns.connections) < Config.MaxTotalConnections
}
// GetStats returns stats for active connections
func (conns *ActiveConnections) GetStats() []ConnectionStatus {
conns.RLock()

View file

@ -225,6 +225,26 @@ func TestSSHConnections(t *testing.T) {
assert.NoError(t, sshConn3.Close())
}
func TestMaxConnections(t *testing.T) {
oldValue := Config.MaxTotalConnections
Config.MaxTotalConnections = 1
assert.True(t, Connections.IsNewConnectionAllowed())
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
fakeConn := &fakeConnection{
BaseConnection: c,
}
Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1)
assert.False(t, Connections.IsNewConnectionAllowed())
res := Connections.Close(fakeConn.GetID())
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
Config.MaxTotalConnections = oldValue
}
func TestIdleConnections(t *testing.T) {
configCopy := Config
@ -310,6 +330,7 @@ func TestCloseConnection(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1)
res := Connections.Close(fakeConn.GetID())

View file

@ -65,9 +65,11 @@ func Init() {
ExecuteOn: []string{},
Hook: "",
},
SetstatMode: 0,
ProxyProtocol: 0,
ProxyAllowed: []string{},
SetstatMode: 0,
ProxyProtocol: 0,
ProxyAllowed: []string{},
PostConnectHook: "",
MaxTotalConnections: 0,
},
SFTPD: sftpd.Configuration{
Banner: defaultSFTPDBanner,
@ -413,6 +415,7 @@ func setViperDefaults() {
viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol)
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort)
viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress)
viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries)

View file

@ -63,6 +63,7 @@ The configuration file contains the following sections:
- If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
- **"sftpd"**, the configuration for the SFTP server
- `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022
- `bind_address`, string. Leave blank to listen on all available network interfaces. Default: ""

View file

@ -502,6 +502,29 @@ func TestPostConnectHook(t *testing.T) {
common.Config.PostConnectHook = ""
}
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
assert.NoError(t, err)
client, err := getFTPClient(user, true)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
_, err = getFTPClient(user, false)
assert.Error(t, err)
err = client.Quit()
assert.NoError(t, err)
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) {
u := getTestUser()
u.MaxSessions = 1

View file

@ -98,8 +98,12 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
// ClientConnected is called to send the very first welcome message
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
return "", common.ErrConnectionDenied
}
if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil {
return common.ErrConnectionDenied.Error(), err
return "", err
}
connID := fmt.Sprintf("%v", cc.ID())
user := dataprovider.User{}

View file

@ -277,23 +277,22 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
}
}()
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
conn.Close()
return
}
// Before beginning a handshake must be performed on the incoming net.Conn
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
remoteAddr := conn.RemoteAddr()
if err := common.Config.ExecutePostConnectHook(remoteAddr.String(), common.ProtocolSSH); err != nil {
if err := common.Config.ExecutePostConnectHook(conn.RemoteAddr().String(), common.ProtocolSSH); err != nil {
conn.Close()
return
}
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil {
logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
if _, ok := err.(*ssh.ServerAuthError); !ok {
ip := utils.GetIPFromRemoteAddress(remoteAddr.String())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
metrics.AddNoAuthTryed()
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
}
checkAuthError(conn, err)
return
}
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
@ -315,7 +314,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
user.ID, loginType, user.Username, user.HomeDir, conn.RemoteAddr().String())
dataprovider.UpdateLastLogin(user) //nolint:errcheck
sshConnection := common.NewSSHConnection(connectionID, conn)
@ -354,13 +353,13 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
switch req.Type {
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
fs, err := user.GetFilesystem(connectionID)
fs, err := user.GetFilesystem(connID)
if err == nil {
ok = true
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
RemoteAddr: conn.RemoteAddr(),
channel: channel,
}
go c.handleSftpConnection(channel, &connection)
@ -368,12 +367,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
}
case "exec":
// protocol will be set later inside processSSHCommand it could be SSH or SCP
fs, err := user.GetFilesystem(connectionID)
fs, err := user.GetFilesystem(connID)
if err == nil {
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
RemoteAddr: conn.RemoteAddr(),
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
@ -420,6 +419,15 @@ func (c *Configuration) createHandler(connection *Connection) sftp.Handlers {
}
}
func checkAuthError(conn net.Conn, err error) {
if _, ok := err.(*ssh.ServerAuthError); !ok {
ip := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
metrics.AddNoAuthTryed()
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
}
}
func checkRootPath(user *dataprovider.User, connectionID string) error {
if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
// for sftp fs check root path does nothing so don't open a useless SFTP connection

View file

@ -2441,6 +2441,31 @@ func TestQuotaDisabledError(t *testing.T) {
assert.NoError(t, err)
}
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
c.Close()
}
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)

View file

@ -9,7 +9,8 @@
"setstat_mode": 0,
"proxy_protocol": 0,
"proxy_allowed": [],
"post_connect_hook": ""
"post_connect_hook": "",
"max_total_connections": 0
},
"sftpd": {
"bind_port": 2022,

View file

@ -112,6 +112,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
}
}()
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
return
}
checkRemoteAddress(r)
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)

View file

@ -650,6 +650,31 @@ func TestPostConnectHook(t *testing.T) {
common.Config.PostConnectHook = ""
}
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
assert.NoError(t, err)
client := getWebDavClient(user)
assert.NoError(t, checkBasicFunc(client))
// now add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), nil)
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
}
common.Connections.Add(connection)
assert.Error(t, checkBasicFunc(client))
common.Connections.Remove(connection.GetID())
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) {
u := getTestUser()
u.MaxSessions = 1