From f34462e3c347d1b64832e85816226367a0f3f431 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Tue, 15 Dec 2020 19:29:30 +0100 Subject: [PATCH] add support for limiting max concurrent client connections --- common/common.go | 16 +++++++++++++++- common/common_test.go | 21 +++++++++++++++++++++ config/config.go | 9 ++++++--- docs/full-configuration.md | 1 + ftpd/ftpd_test.go | 23 +++++++++++++++++++++++ ftpd/server.go | 6 +++++- sftpd/server.go | 34 +++++++++++++++++++++------------- sftpd/sftpd_test.go | 25 +++++++++++++++++++++++++ sftpgo.json | 3 ++- webdavd/server.go | 5 +++++ webdavd/webdavd_test.go | 25 +++++++++++++++++++++++++ 11 files changed, 149 insertions(+), 19 deletions(-) diff --git a/common/common.go b/common/common.go index 1250ebd2..d7b0b600 100644 --- a/common/common.go +++ b/common/common.go @@ -247,7 +247,9 @@ type Configuration struct { // Absolute path to an external program or an HTTP URL to invoke after a user connects // and before he tries to login. It allows you to reject the connection based on the source // ip address. Leave empty do disable. - PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` + PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` + // Maximum number of concurrent client connections. 0 means unlimited + MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"` idleTimeoutAsDuration time.Duration idleLoginTimeout time.Duration } @@ -544,6 +546,18 @@ func (conns *ActiveConnections) checkIdles() { conns.RUnlock() } +// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded +func (conns *ActiveConnections) IsNewConnectionAllowed() bool { + if Config.MaxTotalConnections == 0 { + return true + } + + conns.RLock() + defer conns.RUnlock() + + return len(conns.connections) < Config.MaxTotalConnections +} + // GetStats returns stats for active connections func (conns *ActiveConnections) GetStats() []ConnectionStatus { conns.RLock() diff --git a/common/common_test.go b/common/common_test.go index 225da90a..44f3708e 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -225,6 +225,26 @@ func TestSSHConnections(t *testing.T) { assert.NoError(t, sshConn3.Close()) } +func TestMaxConnections(t *testing.T) { + oldValue := Config.MaxTotalConnections + Config.MaxTotalConnections = 1 + + assert.True(t, Connections.IsNewConnectionAllowed()) + c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil) + fakeConn := &fakeConnection{ + BaseConnection: c, + } + Connections.Add(fakeConn) + assert.Len(t, Connections.GetStats(), 1) + assert.False(t, Connections.IsNewConnectionAllowed()) + + res := Connections.Close(fakeConn.GetID()) + assert.True(t, res) + assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond) + + Config.MaxTotalConnections = oldValue +} + func TestIdleConnections(t *testing.T) { configCopy := Config @@ -310,6 +330,7 @@ func TestCloseConnection(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } + assert.True(t, Connections.IsNewConnectionAllowed()) Connections.Add(fakeConn) assert.Len(t, Connections.GetStats(), 1) res := Connections.Close(fakeConn.GetID()) diff --git a/config/config.go b/config/config.go index 1da9322c..e93181f3 100644 --- a/config/config.go +++ b/config/config.go @@ -65,9 +65,11 @@ func Init() { ExecuteOn: []string{}, Hook: "", }, - SetstatMode: 0, - ProxyProtocol: 0, - ProxyAllowed: []string{}, + SetstatMode: 0, + ProxyProtocol: 0, + ProxyAllowed: []string{}, + PostConnectHook: "", + MaxTotalConnections: 0, }, SFTPD: sftpd.Configuration{ Banner: defaultSFTPDBanner, @@ -413,6 +415,7 @@ func setViperDefaults() { viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol) viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed) viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook) + viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort) viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress) viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries) diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 03b9e0f4..447924c1 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -63,6 +63,7 @@ The configuration file contains the following sections: - If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored - If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected - `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable + - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited - **"sftpd"**, the configuration for the SFTP server - `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022 - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: "" diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index d7de5843..989a804e 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -502,6 +502,29 @@ func TestPostConnectHook(t *testing.T) { common.Config.PostConnectHook = "" } +func TestMaxConnections(t *testing.T) { + oldValue := common.Config.MaxTotalConnections + common.Config.MaxTotalConnections = 1 + + user, _, err := httpd.AddUser(getTestUser(), http.StatusOK) + assert.NoError(t, err) + client, err := getFTPClient(user, true) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + _, err = getFTPClient(user, false) + assert.Error(t, err) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxTotalConnections = oldValue +} + func TestMaxSessions(t *testing.T) { u := getTestUser() u.MaxSessions = 1 diff --git a/ftpd/server.go b/ftpd/server.go index 475ea557..89306119 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -98,8 +98,12 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) { // ClientConnected is called to send the very first welcome message func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { + if !common.Connections.IsNewConnectionAllowed() { + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached") + return "", common.ErrConnectionDenied + } if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil { - return common.ErrConnectionDenied.Error(), err + return "", err } connID := fmt.Sprintf("%v", cc.ID()) user := dataprovider.User{} diff --git a/sftpd/server.go b/sftpd/server.go index f5e955bb..6012d0c2 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -277,23 +277,22 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) } }() + if !common.Connections.IsNewConnectionAllowed() { + logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached") + conn.Close() + return + } // Before beginning a handshake must be performed on the incoming net.Conn // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck - remoteAddr := conn.RemoteAddr() - if err := common.Config.ExecutePostConnectHook(remoteAddr.String(), common.ProtocolSSH); err != nil { + if err := common.Config.ExecutePostConnectHook(conn.RemoteAddr().String(), common.ProtocolSSH); err != nil { conn.Close() return } sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err) - if _, ok := err.(*ssh.ServerAuthError); !ok { - ip := utils.GetIPFromRemoteAddress(remoteAddr.String()) - logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error()) - metrics.AddNoAuthTryed() - dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err) - } + checkAuthError(conn, err) return } // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on @@ -315,7 +314,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", - user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String()) + user.ID, loginType, user.Username, user.HomeDir, conn.RemoteAddr().String()) dataprovider.UpdateLastLogin(user) //nolint:errcheck sshConnection := common.NewSSHConnection(connectionID, conn) @@ -354,13 +353,13 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve switch req.Type { case "subsystem": if string(req.Payload[4:]) == "sftp" { - fs, err := user.GetFilesystem(connectionID) + fs, err := user.GetFilesystem(connID) if err == nil { ok = true connection := Connection{ BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs), ClientVersion: string(sconn.ClientVersion()), - RemoteAddr: remoteAddr, + RemoteAddr: conn.RemoteAddr(), channel: channel, } go c.handleSftpConnection(channel, &connection) @@ -368,12 +367,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve } case "exec": // protocol will be set later inside processSSHCommand it could be SSH or SCP - fs, err := user.GetFilesystem(connectionID) + fs, err := user.GetFilesystem(connID) if err == nil { connection := Connection{ BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs), ClientVersion: string(sconn.ClientVersion()), - RemoteAddr: remoteAddr, + RemoteAddr: conn.RemoteAddr(), channel: channel, } ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) @@ -420,6 +419,15 @@ func (c *Configuration) createHandler(connection *Connection) sftp.Handlers { } } +func checkAuthError(conn net.Conn, err error) { + if _, ok := err.(*ssh.ServerAuthError); !ok { + ip := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) + logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error()) + metrics.AddNoAuthTryed() + dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err) + } +} + func checkRootPath(user *dataprovider.User, connectionID string) error { if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider { // for sftp fs check root path does nothing so don't open a useless SFTP connection diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index dfb48595..8cf293fc 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -2441,6 +2441,31 @@ func TestQuotaDisabledError(t *testing.T) { assert.NoError(t, err) } +func TestMaxConnections(t *testing.T) { + oldValue := common.Config.MaxTotalConnections + common.Config.MaxTotalConnections = 1 + + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpd.AddUser(u, http.StatusOK) + assert.NoError(t, err) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + c, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") { + c.Close() + } + } + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxTotalConnections = oldValue +} + func TestMaxSessions(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) diff --git a/sftpgo.json b/sftpgo.json index 17e48f00..8736fa03 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -9,7 +9,8 @@ "setstat_mode": 0, "proxy_protocol": 0, "proxy_allowed": [], - "post_connect_hook": "" + "post_connect_hook": "", + "max_total_connections": 0 }, "sftpd": { "bind_port": 2022, diff --git a/webdavd/server.go b/webdavd/server.go index 4adedf72..8cabedf6 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -112,6 +112,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) } }() + if !common.Connections.IsNewConnectionAllowed() { + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached") + http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable) + return + } checkRemoteAddress(r) if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index b5dc04ef..ceb3c64c 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -650,6 +650,31 @@ func TestPostConnectHook(t *testing.T) { common.Config.PostConnectHook = "" } +func TestMaxConnections(t *testing.T) { + oldValue := common.Config.MaxTotalConnections + common.Config.MaxTotalConnections = 1 + + user, _, err := httpd.AddUser(getTestUser(), http.StatusOK) + assert.NoError(t, err) + client := getWebDavClient(user) + assert.NoError(t, checkBasicFunc(client)) + // now add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), nil) + connection := &webdavd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + } + common.Connections.Add(connection) + assert.Error(t, checkBasicFunc(client)) + common.Connections.Remove(connection.GetID()) + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(), 0) + + common.Config.MaxTotalConnections = oldValue +} + func TestMaxSessions(t *testing.T) { u := getTestUser() u.MaxSessions = 1