From 8f6cdacd008925f7c94c070e4fda44da91664956 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 8 May 2021 19:45:21 +0200 Subject: [PATCH] allow to limit the number of per-host connections --- common/clientsmap.go | 51 ++++++++++++++++++++++++++++++++ common/clientsmap_test.go | 59 +++++++++++++++++++++++++++++++++++++ common/common.go | 60 +++++++++++++++++++++++++------------- common/common_test.go | 51 +++++++++++++++++++++++++------- common/defender.go | 15 +++++----- common/defender_test.go | 32 ++++++++++---------- common/ratelimiter.go | 2 +- config/config.go | 38 ++++++++++++------------ docs/defender.md | 2 +- docs/full-configuration.md | 5 ++-- docs/rate-limiting.md | 2 +- ftpd/ftpd_test.go | 27 +++++++++++++++-- ftpd/server.go | 6 ++-- httpd/httpd_test.go | 4 +-- httpd/server.go | 9 +++--- httpd/webclient.go | 8 ++--- sftpd/server.go | 8 ++--- sftpd/sftpd_test.go | 38 ++++++++++++++++++++++-- sftpgo.json | 3 +- webdavd/server.go | 12 ++++---- webdavd/webdavd_test.go | 29 +++++++++++++++++- 21 files changed, 356 insertions(+), 105 deletions(-) create mode 100644 common/clientsmap.go create mode 100644 common/clientsmap_test.go diff --git a/common/clientsmap.go b/common/clientsmap.go new file mode 100644 index 00000000..93212a09 --- /dev/null +++ b/common/clientsmap.go @@ -0,0 +1,51 @@ +package common + +import ( + "sync" + "sync/atomic" + + "github.com/drakkan/sftpgo/logger" +) + +// clienstMap is a struct containing the map of the connected clients +type clientsMap struct { + totalConnections int32 + mu sync.RWMutex + clients map[string]int +} + +func (c *clientsMap) add(source string) { + atomic.AddInt32(&c.totalConnections, 1) + + c.mu.Lock() + defer c.mu.Unlock() + + c.clients[source]++ +} + +func (c *clientsMap) remove(source string) { + c.mu.Lock() + defer c.mu.Unlock() + + if val, ok := c.clients[source]; ok { + atomic.AddInt32(&c.totalConnections, -1) + c.clients[source]-- + if val > 1 { + return + } + delete(c.clients, source) + } else { + logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source) + } +} + +func (c *clientsMap) getTotal() int32 { + return atomic.LoadInt32(&c.totalConnections) +} + +func (c *clientsMap) getTotalFrom(source string) int { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.clients[source] +} diff --git a/common/clientsmap_test.go b/common/clientsmap_test.go new file mode 100644 index 00000000..fa851984 --- /dev/null +++ b/common/clientsmap_test.go @@ -0,0 +1,59 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClientsMap(t *testing.T) { + m := clientsMap{ + clients: make(map[string]int), + } + ip1 := "192.168.1.1" + ip2 := "192.168.1.2" + m.add(ip1) + assert.Equal(t, int32(1), m.getTotal()) + assert.Equal(t, 1, m.getTotalFrom(ip1)) + assert.Equal(t, 0, m.getTotalFrom(ip2)) + + m.add(ip1) + m.add(ip2) + assert.Equal(t, int32(3), m.getTotal()) + assert.Equal(t, 2, m.getTotalFrom(ip1)) + assert.Equal(t, 1, m.getTotalFrom(ip2)) + + m.add(ip1) + m.add(ip1) + m.add(ip2) + assert.Equal(t, int32(6), m.getTotal()) + assert.Equal(t, 4, m.getTotalFrom(ip1)) + assert.Equal(t, 2, m.getTotalFrom(ip2)) + + m.remove(ip2) + assert.Equal(t, int32(5), m.getTotal()) + assert.Equal(t, 4, m.getTotalFrom(ip1)) + assert.Equal(t, 1, m.getTotalFrom(ip2)) + + m.remove("unknown") + assert.Equal(t, int32(5), m.getTotal()) + assert.Equal(t, 4, m.getTotalFrom(ip1)) + assert.Equal(t, 1, m.getTotalFrom(ip2)) + + m.remove(ip2) + assert.Equal(t, int32(4), m.getTotal()) + assert.Equal(t, 4, m.getTotalFrom(ip1)) + assert.Equal(t, 0, m.getTotalFrom(ip2)) + + m.remove(ip1) + m.remove(ip1) + m.remove(ip1) + assert.Equal(t, int32(1), m.getTotal()) + assert.Equal(t, 1, m.getTotalFrom(ip1)) + assert.Equal(t, 0, m.getTotalFrom(ip2)) + + m.remove(ip1) + assert.Equal(t, int32(0), m.getTotal()) + assert.Equal(t, 0, m.getTotalFrom(ip1)) + assert.Equal(t, 0, m.getTotalFrom(ip2)) +} diff --git a/common/common.go b/common/common.go index 45557af5..fe753aca 100644 --- a/common/common.go +++ b/common/common.go @@ -80,6 +80,12 @@ const ( UploadModeAtomicWithResume ) +func init() { + Connections.clients = clientsMap{ + clients: make(map[string]int), + } +} + // errors definitions var ( ErrPermissionDenied = errors.New("permission denied") @@ -352,6 +358,8 @@ type Configuration struct { PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` // Maximum number of concurrent client connections. 0 means unlimited MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"` + // Maximum number of concurrent client connections from the same host (IP). 0 means unlimited + MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"` // Defender configuration DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"` // Rate limiter configurations @@ -524,9 +532,9 @@ func (c *SSHConnection) Close() error { // ActiveConnections holds the currect active connections with the associated transfers type ActiveConnections struct { - // networkConnections is the counter for the network connections, it contains - // both authenticated and estabilished connections and the ones waiting for authentication - networkConnections int32 + // clients contains both authenticated and estabilished connections and the ones waiting + // for authentication + clients clientsMap sync.RWMutex connections []ActiveConnection sshConnections []*SSHConnection @@ -693,34 +701,46 @@ func (conns *ActiveConnections) checkIdles() { conns.RUnlock() } -// AddNetworkConnection increments the network connections counter -func (conns *ActiveConnections) AddNetworkConnection() { - atomic.AddInt32(&conns.networkConnections, 1) +// AddClientConnection stores a new client connection +func (conns *ActiveConnections) AddClientConnection(ipAddr string) { + conns.clients.add(ipAddr) } -// RemoveNetworkConnection decrements the network connections counter -func (conns *ActiveConnections) RemoveNetworkConnection() { - atomic.AddInt32(&conns.networkConnections, -1) +// RemoveClientConnection removes a disconnected client from the tracked ones +func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) { + conns.clients.remove(ipAddr) } // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded -func (conns *ActiveConnections) IsNewConnectionAllowed() bool { - if Config.MaxTotalConnections == 0 { +func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool { + if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { return true } - num := atomic.LoadInt32(&conns.networkConnections) - if num > int32(Config.MaxTotalConnections) { - logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections) - return false + if Config.MaxPerHostConnections > 0 { + if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections { + logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections) + AddDefenderEvent(ipAddr, HostEventLimitExceeded) + return false + } } - // on a single SFTP connection we could have multiple SFTP channels or commands - // so we check the estabilished connections too - conns.RLock() - defer conns.RUnlock() + if Config.MaxTotalConnections > 0 { + if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) { + logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections) + return false + } - return len(conns.connections) < Config.MaxTotalConnections + // on a single SFTP connection we could have multiple SFTP channels or commands + // so we check the estabilished connections too + + conns.RLock() + defer conns.RUnlock() + + return len(conns.connections) < Config.MaxTotalConnections + } + + return true } // GetStats returns stats for active connections diff --git a/common/common_test.go b/common/common_test.go index 64ec0ccb..875f87ba 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -228,32 +228,61 @@ func TestRateLimitersIntegration(t *testing.T) { func TestMaxConnections(t *testing.T) { oldValue := Config.MaxTotalConnections - Config.MaxTotalConnections = 1 + perHost := Config.MaxPerHostConnections - assert.True(t, Connections.IsNewConnectionAllowed()) + Config.MaxPerHostConnections = 0 + + ipAddr := "192.168.7.8" + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + + Config.MaxTotalConnections = 1 + Config.MaxPerHostConnections = perHost + + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, } Connections.Add(fakeConn) assert.Len(t, Connections.GetStats(), 1) - assert.False(t, Connections.IsNewConnectionAllowed()) + assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) res := Connections.Close(fakeConn.GetID()) assert.True(t, res) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond) - assert.True(t, Connections.IsNewConnectionAllowed()) - Connections.AddNetworkConnection() - Connections.AddNetworkConnection() - assert.False(t, Connections.IsNewConnectionAllowed()) - Connections.RemoveNetworkConnection() - assert.True(t, Connections.IsNewConnectionAllowed()) - Connections.RemoveNetworkConnection() + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + Connections.AddClientConnection(ipAddr) + Connections.AddClientConnection(ipAddr) + assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + Connections.RemoveClientConnection(ipAddr) + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + Connections.RemoveClientConnection(ipAddr) Config.MaxTotalConnections = oldValue } +func TestMaxConnectionPerHost(t *testing.T) { + oldValue := Config.MaxPerHostConnections + + Config.MaxPerHostConnections = 2 + + ipAddr := "192.168.9.9" + Connections.AddClientConnection(ipAddr) + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + + Connections.AddClientConnection(ipAddr) + assert.True(t, Connections.IsNewConnectionAllowed(ipAddr)) + + Connections.AddClientConnection(ipAddr) + assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + + Connections.RemoveClientConnection(ipAddr) + Connections.RemoveClientConnection(ipAddr) + + Config.MaxPerHostConnections = oldValue +} + func TestIdleConnections(t *testing.T) { configCopy := Config @@ -340,7 +369,7 @@ func TestCloseConnection(t *testing.T) { fakeConn := &fakeConnection{ BaseConnection: c, } - assert.True(t, Connections.IsNewConnectionAllowed()) + assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1")) Connections.Add(fakeConn) assert.Len(t, Connections.GetStats(), 1) res := Connections.Close(fakeConn.GetID()) diff --git a/common/defender.go b/common/defender.go index a15e37df..df2c8d92 100644 --- a/common/defender.go +++ b/common/defender.go @@ -23,7 +23,7 @@ const ( HostEventLoginFailed HostEvent = iota HostEventUserNotFound HostEventNoLoginTried - HostEventRateExceeded + HostEventLimitExceeded ) // Defender defines the interface that a defender must implements @@ -51,8 +51,9 @@ type DefenderConfig struct { ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"` // Score for valid login attempts, eg. user accounts that exist ScoreValid int `json:"score_valid" mapstructure:"score_valid"` - // Score for rate exceeded events, generated from the rate limiters - ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"` + // Score for limit exceeded events, generated from the rate limiters or for max connections + // per-host exceeded + ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"` // Defines the time window, in minutes, for tracking client errors. // A host is banned if it has exceeded the defined threshold during // the last observation time minutes @@ -126,8 +127,8 @@ func (c *DefenderConfig) validate() error { if c.ScoreValid >= c.Threshold { return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold) } - if c.ScoreRateExceeded >= c.Threshold { - return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold) + if c.ScoreLimitExceeded >= c.Threshold { + return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold) } if c.BanTime <= 0 { return fmt.Errorf("invalid ban_time %v", c.BanTime) @@ -254,8 +255,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) { switch event { case HostEventLoginFailed: score = d.config.ScoreValid - case HostEventRateExceeded: - score = d.config.ScoreRateExceeded + case HostEventLimitExceeded: + score = d.config.ScoreLimitExceeded case HostEventUserNotFound, HostEventNoLoginTried: score = d.config.ScoreInvalid } diff --git a/common/defender_test.go b/common/defender_test.go index 29975d3b..db3fe41c 100644 --- a/common/defender_test.go +++ b/common/defender_test.go @@ -41,18 +41,18 @@ func TestBasicDefender(t *testing.T) { assert.NoError(t, err) config := &DefenderConfig{ - Enabled: true, - BanTime: 10, - BanTimeIncrement: 2, - Threshold: 5, - ScoreInvalid: 2, - ScoreValid: 1, - ScoreRateExceeded: 3, - ObservationTime: 15, - EntriesSoftLimit: 1, - EntriesHardLimit: 2, - SafeListFile: "slFile", - BlockListFile: "blFile", + Enabled: true, + BanTime: 10, + BanTimeIncrement: 2, + Threshold: 5, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreLimitExceeded: 3, + ObservationTime: 15, + EntriesSoftLimit: 1, + EntriesHardLimit: 2, + SafeListFile: "slFile", + BlockListFile: "blFile", } _, err = newInMemoryDefender(config) @@ -75,7 +75,7 @@ func TestBasicDefender(t *testing.T) { defender.AddEvent("172.16.1.4", HostEventLoginFailed) defender.AddEvent("192.168.8.4", HostEventUserNotFound) - defender.AddEvent("172.16.1.3", HostEventRateExceeded) + defender.AddEvent("172.16.1.3", HostEventLimitExceeded) assert.Equal(t, 0, defender.countHosts()) testIP := "12.34.56.78" @@ -84,7 +84,7 @@ func TestBasicDefender(t *testing.T) { assert.Equal(t, 0, defender.countBanned()) assert.Equal(t, 1, defender.GetScore(testIP)) assert.Nil(t, defender.GetBanTime(testIP)) - defender.AddEvent(testIP, HostEventRateExceeded) + defender.AddEvent(testIP, HostEventLimitExceeded) assert.Equal(t, 1, defender.countHosts()) assert.Equal(t, 0, defender.countBanned()) assert.Equal(t, 4, defender.GetScore(testIP)) @@ -317,11 +317,11 @@ func TestDefenderConfig(t *testing.T) { require.Error(t, err) c.ScoreInvalid = 2 - c.ScoreRateExceeded = 10 + c.ScoreLimitExceeded = 10 err = c.validate() require.Error(t, err) - c.ScoreRateExceeded = 2 + c.ScoreLimitExceeded = 2 c.ScoreValid = 10 err = c.validate() require.Error(t, err) diff --git a/common/ratelimiter.go b/common/ratelimiter.go index 4104b957..2c387508 100644 --- a/common/ratelimiter.go +++ b/common/ratelimiter.go @@ -149,7 +149,7 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) { if delay > rl.maxDelay { res.Cancel() if rl.generateDefenderEvents && rl.globalBucket == nil { - AddDefenderEvent(source, HostEventRateExceeded) + AddDefenderEvent(source, HostEventLimitExceeded) } return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) } diff --git a/config/config.go b/config/config.go index 33bac3e2..b3aa3a4d 100644 --- a/config/config.go +++ b/config/config.go @@ -111,24 +111,25 @@ func Init() { ExecuteOn: []string{}, Hook: "", }, - SetstatMode: 0, - ProxyProtocol: 0, - ProxyAllowed: []string{}, - PostConnectHook: "", - MaxTotalConnections: 0, + SetstatMode: 0, + ProxyProtocol: 0, + ProxyAllowed: []string{}, + PostConnectHook: "", + MaxTotalConnections: 0, + MaxPerHostConnections: 20, DefenderConfig: common.DefenderConfig{ - Enabled: false, - BanTime: 30, - BanTimeIncrement: 50, - Threshold: 15, - ScoreInvalid: 2, - ScoreValid: 1, - ScoreRateExceeded: 3, - ObservationTime: 30, - EntriesSoftLimit: 100, - EntriesHardLimit: 150, - SafeListFile: "", - BlockListFile: "", + Enabled: false, + BanTime: 30, + BanTimeIncrement: 50, + Threshold: 15, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreLimitExceeded: 3, + ObservationTime: 30, + EntriesSoftLimit: 100, + EntriesHardLimit: 150, + SafeListFile: "", + BlockListFile: "", }, RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter}, }, @@ -873,13 +874,14 @@ func setViperDefaults() { viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed) viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook) viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) + viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections) viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled) viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime) viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement) viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold) viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid) viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid) - viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded) + viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded) viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime) viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit) viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit) diff --git a/docs/defender.md b/docs/defender.md index 24da3e49..a646eeab 100644 --- a/docs/defender.md +++ b/docs/defender.md @@ -8,7 +8,7 @@ You can configure a score for each event type: - `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`. - `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`. -- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`. +- `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`. And then you can configure: diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 70618a33..d49dd544 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -64,7 +64,8 @@ The configuration file contains the following sections: - If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected - `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable - `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable - - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited + - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0. + - `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20. - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details. - `enabled`, boolean. Default `false`. - `ban_time`, integer. Ban time in minutes. @@ -72,7 +73,7 @@ The configuration file contains the following sections: - `threshold`, integer. Threshold value for banning a client. - `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. - `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist. - - `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits. + - `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections. - `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes. - `entries_soft_limit`, integer. - `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit. diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md index b2c9dd4f..6c1631fa 100644 --- a/docs/rate-limiting.md +++ b/docs/rate-limiting.md @@ -18,7 +18,7 @@ The supported protocols are: You can also define two types of rate limiters: - global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s -- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked +- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys. diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 35b718e7..3b28d450 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -779,13 +779,36 @@ func TestMaxConnections(t *testing.T) { common.Config.MaxTotalConnections = oldValue } +func TestMaxPerHostConnections(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 1 + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + _, err = getFTPClient(user, false, nil) + assert.Error(t, err) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + func TestRateLimiter(t *testing.T) { oldConfig := config.GetCommonConfig() cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 5 - cfg.DefenderConfig.ScoreRateExceeded = 3 + cfg.DefenderConfig.ScoreLimitExceeded = 3 cfg.RateLimitersConfig = []common.RateLimiterConfig{ { Average: 1, @@ -843,7 +866,7 @@ func TestDefender(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 - cfg.DefenderConfig.ScoreRateExceeded = 2 + cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg) assert.NoError(t, err) diff --git a/ftpd/server.go b/ftpd/server.go index 153858de..a15fb9b3 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -135,13 +135,13 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) { // ClientConnected is called to send the very first welcome message func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { - common.Connections.AddNetworkConnection() ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String()) + common.Connections.AddClientConnection(ipAddr) if common.IsBanned(ipAddr) { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr) return "Access denied: banned client IP", common.ErrConnectionDenied } - if !common.Connections.IsNewConnectionAllowed() { + if !common.Connections.IsNewConnectionAllowed(ipAddr) { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached") return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied } @@ -167,7 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) { s.cleanTLSConnVerification(cc.ID()) connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()) common.Connections.Remove(connID) - common.Connections.RemoveNetworkConnection() + common.Connections.RemoveClientConnection(utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())) } // AuthUser authenticates the user and selects an handling driver diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index e29418e6..b56d196b 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -2945,7 +2945,7 @@ func TestDefenderAPI(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 - cfg.DefenderConfig.ScoreRateExceeded = 2 + cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg) require.NoError(t, err) @@ -4615,7 +4615,7 @@ func TestDefender(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 - cfg.DefenderConfig.ScoreRateExceeded = 2 + cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg) assert.NoError(t, err) diff --git a/httpd/server.go b/httpd/server.go index 09348283..5895f335 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -110,8 +110,10 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler { func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize) - common.Connections.AddNetworkConnection() - defer common.Connections.RemoveNetworkConnection() + + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) if err := r.ParseForm(); err != nil { renderClientLoginPage(w, err.Error()) @@ -128,8 +130,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re return } - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) - if !common.Connections.IsNewConnectionAllowed() { + if !common.Connections.IsNewConnectionAllowed(ipAddr) { logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") renderClientLoginPage(w, "configured connections limit reached") return diff --git a/httpd/webclient.go b/httpd/webclient.go index 97bfe669..280fcee7 100644 --- a/httpd/webclient.go +++ b/httpd/webclient.go @@ -279,20 +279,20 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) { } func handleClientGetFiles(w http.ResponseWriter, r *http.Request) { - common.Connections.AddNetworkConnection() - defer common.Connections.RemoveNetworkConnection() + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) claims, err := getTokenClaims(r) if err != nil || claims.Username == "" { renderClientForbiddenPage(w, r, "Invalid token claims") return } - if !common.Connections.IsNewConnectionAllowed() { + if !common.Connections.IsNewConnectionAllowed(ipAddr) { logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") renderClientForbiddenPage(w, r, "configured connections limit reached") return } - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) if common.IsBanned(ipAddr) { renderClientForbiddenPage(w, r, "your IP address is banned") return diff --git a/sftpd/server.go b/sftpd/server.go index 4a3bf896..fc7e954a 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -356,7 +356,7 @@ func canAcceptConnection(ip string) bool { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip) return false } - if !common.Connections.IsNewConnectionAllowed() { + if !common.Connections.IsNewConnectionAllowed(ip) { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached") return false } @@ -378,10 +378,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve } }() - common.Connections.AddNetworkConnection() - defer common.Connections.RemoveNetworkConnection() - ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()) + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) + if !canAcceptConnection(ipAddr) { conn.Close() return diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 03f7a71d..df1b0246 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -527,7 +527,7 @@ func TestDefender(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 - cfg.DefenderConfig.ScoreRateExceeded = 2 + cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg) assert.NoError(t, err) @@ -663,6 +663,9 @@ func TestOpenReadWritePerm(t *testing.T) { } func TestConcurrency(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 0 + usePubKey := true numLogins := 50 u := getTestUser(usePubKey) @@ -747,6 +750,8 @@ func TestConcurrency(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue } func TestProxyProtocol(t *testing.T) { @@ -2896,6 +2901,7 @@ func TestQuotaDisabledError(t *testing.T) { assert.NoError(t, err) } +//nolint:dupl func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 @@ -2910,7 +2916,7 @@ func TestMaxConnections(t *testing.T) { defer client.Close() assert.NoError(t, checkBasicSFTP(client)) s, c, err := getSftpClient(user, usePubKey) - if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") { + if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") { c.Close() s.Close() } @@ -2923,6 +2929,34 @@ func TestMaxConnections(t *testing.T) { common.Config.MaxTotalConnections = oldValue } +//nolint:dupl +func TestMaxPerHostConnections(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 1 + + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + s, c, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") { + c.Close() + s.Close() + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + func TestMaxSessions(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) diff --git a/sftpgo.json b/sftpgo.json index 2a9932f9..ddaafec3 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -12,6 +12,7 @@ "startup_hook": "", "post_connect_hook": "", "max_total_connections": 0, + "max_per_host_connections": 20, "defender": { "enabled": false, "ban_time": 30, @@ -19,7 +20,7 @@ "threshold": 15, "score_invalid": 2, "score_valid": 1, - "score_rate_exceeded": 3, + "score_limit_exceeded": 3, "observation_time": 30, "entries_soft_limit": 100, "entries_hard_limit": 150, diff --git a/webdavd/server.go b/webdavd/server.go index 9f32b688..f0aa3d30 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -144,16 +144,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) } }() - common.Connections.AddNetworkConnection() - defer common.Connections.RemoveNetworkConnection() - if !common.Connections.IsNewConnectionAllowed() { + checkRemoteAddress(r) + ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) + + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) + + if !common.Connections.IsNewConnectionAllowed(ipAddr) { logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached") http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable) return } - checkRemoteAddress(r) - ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr) if common.IsBanned(ipAddr) { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) return diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index e1188ab5..91c7e7bf 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -747,7 +747,7 @@ func TestDefender(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 - cfg.DefenderConfig.ScoreRateExceeded = 2 + cfg.DefenderConfig.ScoreLimitExceeded = 2 err := common.Initialize(cfg) assert.NoError(t, err) @@ -934,6 +934,33 @@ func TestMaxConnections(t *testing.T) { common.Config.MaxTotalConnections = oldValue } +func TestMaxPerHostConnections(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 1 + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + // now add a fake connection + addrs, err := net.LookupHost("localhost") + assert.NoError(t, err) + for _, addr := range addrs { + common.Connections.AddClientConnection(addr) + } + assert.Error(t, checkBasicFunc(client)) + for _, addr := range addrs { + common.Connections.RemoveClientConnection(addr) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(), 0) + + common.Config.MaxPerHostConnections = oldValue +} + func TestMaxSessions(t *testing.T) { u := getTestUser() u.MaxSessions = 1