Kaynağa Gözat

allow to limit the number of per-host connections

Nicola Murino 4 yıl önce
ebeveyn
işleme
8f6cdacd00

+ 51 - 0
common/clientsmap.go

@@ -0,0 +1,51 @@
+package common
+
+import (
+	"sync"
+	"sync/atomic"
+
+	"github.com/drakkan/sftpgo/logger"
+)
+
+// clienstMap is a struct containing the map of the connected clients
+type clientsMap struct {
+	totalConnections int32
+	mu               sync.RWMutex
+	clients          map[string]int
+}
+
+func (c *clientsMap) add(source string) {
+	atomic.AddInt32(&c.totalConnections, 1)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.clients[source]++
+}
+
+func (c *clientsMap) remove(source string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if val, ok := c.clients[source]; ok {
+		atomic.AddInt32(&c.totalConnections, -1)
+		c.clients[source]--
+		if val > 1 {
+			return
+		}
+		delete(c.clients, source)
+	} else {
+		logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
+	}
+}
+
+func (c *clientsMap) getTotal() int32 {
+	return atomic.LoadInt32(&c.totalConnections)
+}
+
+func (c *clientsMap) getTotalFrom(source string) int {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.clients[source]
+}

+ 59 - 0
common/clientsmap_test.go

@@ -0,0 +1,59 @@
+package common
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestClientsMap(t *testing.T) {
+	m := clientsMap{
+		clients: make(map[string]int),
+	}
+	ip1 := "192.168.1.1"
+	ip2 := "192.168.1.2"
+	m.add(ip1)
+	assert.Equal(t, int32(1), m.getTotal())
+	assert.Equal(t, 1, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.add(ip1)
+	m.add(ip2)
+	assert.Equal(t, int32(3), m.getTotal())
+	assert.Equal(t, 2, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.add(ip1)
+	m.add(ip1)
+	m.add(ip2)
+	assert.Equal(t, int32(6), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 2, m.getTotalFrom(ip2))
+
+	m.remove(ip2)
+	assert.Equal(t, int32(5), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.remove("unknown")
+	assert.Equal(t, int32(5), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.remove(ip2)
+	assert.Equal(t, int32(4), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.remove(ip1)
+	m.remove(ip1)
+	m.remove(ip1)
+	assert.Equal(t, int32(1), m.getTotal())
+	assert.Equal(t, 1, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.remove(ip1)
+	assert.Equal(t, int32(0), m.getTotal())
+	assert.Equal(t, 0, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+}

+ 40 - 20
common/common.go

@@ -80,6 +80,12 @@ const (
 	UploadModeAtomicWithResume
 )
 
+func init() {
+	Connections.clients = clientsMap{
+		clients: make(map[string]int),
+	}
+}
+
 // errors definitions
 var (
 	ErrPermissionDenied     = errors.New("permission denied")
@@ -352,6 +358,8 @@ type Configuration struct {
 	PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
 	// Maximum number of concurrent client connections. 0 means unlimited
 	MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
+	// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
+	MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
 	// Defender configuration
 	DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
 	// Rate limiter configurations
@@ -524,9 +532,9 @@ func (c *SSHConnection) Close() error {
 
 // ActiveConnections holds the currect active connections with the associated transfers
 type ActiveConnections struct {
-	// networkConnections is the counter for the network connections, it contains
-	// both authenticated and estabilished connections and the ones waiting for authentication
-	networkConnections int32
+	// clients contains both authenticated and estabilished connections and the ones waiting
+	// for authentication
+	clients clientsMap
 	sync.RWMutex
 	connections    []ActiveConnection
 	sshConnections []*SSHConnection
@@ -693,34 +701,46 @@ func (conns *ActiveConnections) checkIdles() {
 	conns.RUnlock()
 }
 
-// AddNetworkConnection increments the network connections counter
-func (conns *ActiveConnections) AddNetworkConnection() {
-	atomic.AddInt32(&conns.networkConnections, 1)
+// AddClientConnection stores a new client connection
+func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
+	conns.clients.add(ipAddr)
 }
 
-// RemoveNetworkConnection decrements the network connections counter
-func (conns *ActiveConnections) RemoveNetworkConnection() {
-	atomic.AddInt32(&conns.networkConnections, -1)
+// RemoveClientConnection removes a disconnected client from the tracked ones
+func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
+	conns.clients.remove(ipAddr)
 }
 
 // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
-func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
-	if Config.MaxTotalConnections == 0 {
+func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
+	if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
 		return true
 	}
 
-	num := atomic.LoadInt32(&conns.networkConnections)
-	if num > int32(Config.MaxTotalConnections) {
-		logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections)
-		return false
+	if Config.MaxPerHostConnections > 0 {
+		if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
+			logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
+			AddDefenderEvent(ipAddr, HostEventLimitExceeded)
+			return false
+		}
 	}
-	// on a single SFTP connection we could have multiple SFTP channels or commands
-	// so we check the estabilished connections too
 
-	conns.RLock()
-	defer conns.RUnlock()
+	if Config.MaxTotalConnections > 0 {
+		if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
+			logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
+			return false
+		}
+
+		// on a single SFTP connection we could have multiple SFTP channels or commands
+		// so we check the estabilished connections too
+
+		conns.RLock()
+		defer conns.RUnlock()
 
-	return len(conns.connections) < Config.MaxTotalConnections
+		return len(conns.connections) < Config.MaxTotalConnections
+	}
+
+	return true
 }
 
 // GetStats returns stats for active connections

+ 39 - 10
common/common_test.go

@@ -228,32 +228,61 @@ func TestRateLimitersIntegration(t *testing.T) {
 
 func TestMaxConnections(t *testing.T) {
 	oldValue := Config.MaxTotalConnections
+	perHost := Config.MaxPerHostConnections
+
+	Config.MaxPerHostConnections = 0
+
+	ipAddr := "192.168.7.8"
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
+
 	Config.MaxTotalConnections = 1
+	Config.MaxPerHostConnections = perHost
 
-	assert.True(t, Connections.IsNewConnectionAllowed())
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
 	c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
 	Connections.Add(fakeConn)
 	assert.Len(t, Connections.GetStats(), 1)
-	assert.False(t, Connections.IsNewConnectionAllowed())
+	assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
 
 	res := Connections.Close(fakeConn.GetID())
 	assert.True(t, res)
 	assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
 
-	assert.True(t, Connections.IsNewConnectionAllowed())
-	Connections.AddNetworkConnection()
-	Connections.AddNetworkConnection()
-	assert.False(t, Connections.IsNewConnectionAllowed())
-	Connections.RemoveNetworkConnection()
-	assert.True(t, Connections.IsNewConnectionAllowed())
-	Connections.RemoveNetworkConnection()
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
+	Connections.AddClientConnection(ipAddr)
+	Connections.AddClientConnection(ipAddr)
+	assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
+	Connections.RemoveClientConnection(ipAddr)
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
+	Connections.RemoveClientConnection(ipAddr)
 
 	Config.MaxTotalConnections = oldValue
 }
 
+func TestMaxConnectionPerHost(t *testing.T) {
+	oldValue := Config.MaxPerHostConnections
+
+	Config.MaxPerHostConnections = 2
+
+	ipAddr := "192.168.9.9"
+	Connections.AddClientConnection(ipAddr)
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
+
+	Connections.AddClientConnection(ipAddr)
+	assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
+
+	Connections.AddClientConnection(ipAddr)
+	assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
+
+	Connections.RemoveClientConnection(ipAddr)
+	Connections.RemoveClientConnection(ipAddr)
+
+	Config.MaxPerHostConnections = oldValue
+}
+
 func TestIdleConnections(t *testing.T) {
 	configCopy := Config
 
@@ -340,7 +369,7 @@ func TestCloseConnection(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
-	assert.True(t, Connections.IsNewConnectionAllowed())
+	assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
 	Connections.Add(fakeConn)
 	assert.Len(t, Connections.GetStats(), 1)
 	res := Connections.Close(fakeConn.GetID())

+ 8 - 7
common/defender.go

@@ -23,7 +23,7 @@ const (
 	HostEventLoginFailed HostEvent = iota
 	HostEventUserNotFound
 	HostEventNoLoginTried
-	HostEventRateExceeded
+	HostEventLimitExceeded
 )
 
 // Defender defines the interface that a defender must implements
@@ -51,8 +51,9 @@ type DefenderConfig struct {
 	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
 	// Score for valid login attempts, eg. user accounts that exist
 	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
-	// Score for rate exceeded events, generated from the rate limiters
-	ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
+	// Score for limit exceeded events, generated from the rate limiters or for max connections
+	// per-host exceeded
+	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
 	// Defines the time window, in minutes, for tracking client errors.
 	// A host is banned if it has exceeded the defined threshold during
 	// the last observation time minutes
@@ -126,8 +127,8 @@ func (c *DefenderConfig) validate() error {
 	if c.ScoreValid >= c.Threshold {
 		return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
 	}
-	if c.ScoreRateExceeded >= c.Threshold {
-		return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
+	if c.ScoreLimitExceeded >= c.Threshold {
+		return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
 	}
 	if c.BanTime <= 0 {
 		return fmt.Errorf("invalid ban_time %v", c.BanTime)
@@ -254,8 +255,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
 	switch event {
 	case HostEventLoginFailed:
 		score = d.config.ScoreValid
-	case HostEventRateExceeded:
-		score = d.config.ScoreRateExceeded
+	case HostEventLimitExceeded:
+		score = d.config.ScoreLimitExceeded
 	case HostEventUserNotFound, HostEventNoLoginTried:
 		score = d.config.ScoreInvalid
 	}

+ 16 - 16
common/defender_test.go

@@ -41,18 +41,18 @@ func TestBasicDefender(t *testing.T) {
 	assert.NoError(t, err)
 
 	config := &DefenderConfig{
-		Enabled:           true,
-		BanTime:           10,
-		BanTimeIncrement:  2,
-		Threshold:         5,
-		ScoreInvalid:      2,
-		ScoreValid:        1,
-		ScoreRateExceeded: 3,
-		ObservationTime:   15,
-		EntriesSoftLimit:  1,
-		EntriesHardLimit:  2,
-		SafeListFile:      "slFile",
-		BlockListFile:     "blFile",
+		Enabled:            true,
+		BanTime:            10,
+		BanTimeIncrement:   2,
+		Threshold:          5,
+		ScoreInvalid:       2,
+		ScoreValid:         1,
+		ScoreLimitExceeded: 3,
+		ObservationTime:    15,
+		EntriesSoftLimit:   1,
+		EntriesHardLimit:   2,
+		SafeListFile:       "slFile",
+		BlockListFile:      "blFile",
 	}
 
 	_, err = newInMemoryDefender(config)
@@ -75,7 +75,7 @@ func TestBasicDefender(t *testing.T) {
 
 	defender.AddEvent("172.16.1.4", HostEventLoginFailed)
 	defender.AddEvent("192.168.8.4", HostEventUserNotFound)
-	defender.AddEvent("172.16.1.3", HostEventRateExceeded)
+	defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
 	assert.Equal(t, 0, defender.countHosts())
 
 	testIP := "12.34.56.78"
@@ -84,7 +84,7 @@ func TestBasicDefender(t *testing.T) {
 	assert.Equal(t, 0, defender.countBanned())
 	assert.Equal(t, 1, defender.GetScore(testIP))
 	assert.Nil(t, defender.GetBanTime(testIP))
-	defender.AddEvent(testIP, HostEventRateExceeded)
+	defender.AddEvent(testIP, HostEventLimitExceeded)
 	assert.Equal(t, 1, defender.countHosts())
 	assert.Equal(t, 0, defender.countBanned())
 	assert.Equal(t, 4, defender.GetScore(testIP))
@@ -317,11 +317,11 @@ func TestDefenderConfig(t *testing.T) {
 	require.Error(t, err)
 
 	c.ScoreInvalid = 2
-	c.ScoreRateExceeded = 10
+	c.ScoreLimitExceeded = 10
 	err = c.validate()
 	require.Error(t, err)
 
-	c.ScoreRateExceeded = 2
+	c.ScoreLimitExceeded = 2
 	c.ScoreValid = 10
 	err = c.validate()
 	require.Error(t, err)

+ 1 - 1
common/ratelimiter.go

@@ -149,7 +149,7 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
 	if delay > rl.maxDelay {
 		res.Cancel()
 		if rl.generateDefenderEvents && rl.globalBucket == nil {
-			AddDefenderEvent(source, HostEventRateExceeded)
+			AddDefenderEvent(source, HostEventLimitExceeded)
 		}
 		return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
 	}

+ 20 - 18
config/config.go

@@ -111,24 +111,25 @@ func Init() {
 				ExecuteOn: []string{},
 				Hook:      "",
 			},
-			SetstatMode:         0,
-			ProxyProtocol:       0,
-			ProxyAllowed:        []string{},
-			PostConnectHook:     "",
-			MaxTotalConnections: 0,
+			SetstatMode:           0,
+			ProxyProtocol:         0,
+			ProxyAllowed:          []string{},
+			PostConnectHook:       "",
+			MaxTotalConnections:   0,
+			MaxPerHostConnections: 20,
 			DefenderConfig: common.DefenderConfig{
-				Enabled:           false,
-				BanTime:           30,
-				BanTimeIncrement:  50,
-				Threshold:         15,
-				ScoreInvalid:      2,
-				ScoreValid:        1,
-				ScoreRateExceeded: 3,
-				ObservationTime:   30,
-				EntriesSoftLimit:  100,
-				EntriesHardLimit:  150,
-				SafeListFile:      "",
-				BlockListFile:     "",
+				Enabled:            false,
+				BanTime:            30,
+				BanTimeIncrement:   50,
+				Threshold:          15,
+				ScoreInvalid:       2,
+				ScoreValid:         1,
+				ScoreLimitExceeded: 3,
+				ObservationTime:    30,
+				EntriesSoftLimit:   100,
+				EntriesHardLimit:   150,
+				SafeListFile:       "",
+				BlockListFile:      "",
 			},
 			RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter},
 		},
@@ -873,13 +874,14 @@ func setViperDefaults() {
 	viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
 	viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
 	viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
+	viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
 	viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
 	viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
 	viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
 	viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
 	viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
 	viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
-	viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded)
+	viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
 	viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
 	viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
 	viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)

+ 1 - 1
docs/defender.md

@@ -8,7 +8,7 @@ You can configure a score for each event type:
 
 - `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
 - `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
-- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`.
+- `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
 
 And then you can configure:
 

+ 3 - 2
docs/full-configuration.md

@@ -64,7 +64,8 @@ The configuration file contains the following sections:
     - If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
   - `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable
   - `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
-  - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
+  - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0.
+  - `max_per_host_connections`, integer.  Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
   - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
     - `enabled`, boolean. Default `false`.
     - `ban_time`, integer. Ban time in minutes.
@@ -72,7 +73,7 @@ The configuration file contains the following sections:
     - `threshold`, integer. Threshold value for banning a client.
     - `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
     - `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
-    - `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits.
+    - `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
     - `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
     - `entries_soft_limit`, integer.
     - `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.

+ 1 - 1
docs/rate-limiting.md

@@ -18,7 +18,7 @@ The supported protocols are:
 You can also define two types of rate limiters:
 
 - global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
-- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
+- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
 
 If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
 

+ 25 - 2
ftpd/ftpd_test.go

@@ -779,13 +779,36 @@ func TestMaxConnections(t *testing.T) {
 	common.Config.MaxTotalConnections = oldValue
 }
 
+func TestMaxPerHostConnections(t *testing.T) {
+	oldValue := common.Config.MaxPerHostConnections
+	common.Config.MaxPerHostConnections = 1
+
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+	client, err := getFTPClient(user, true, nil)
+	if assert.NoError(t, err) {
+		err = checkBasicFTP(client)
+		assert.NoError(t, err)
+		_, err = getFTPClient(user, false, nil)
+		assert.Error(t, err)
+		err = client.Quit()
+		assert.NoError(t, err)
+	}
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	common.Config.MaxPerHostConnections = oldValue
+}
+
 func TestRateLimiter(t *testing.T) {
 	oldConfig := config.GetCommonConfig()
 
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 5
-	cfg.DefenderConfig.ScoreRateExceeded = 3
+	cfg.DefenderConfig.ScoreLimitExceeded = 3
 	cfg.RateLimitersConfig = []common.RateLimiterConfig{
 		{
 			Average:                1,
@@ -843,7 +866,7 @@ func TestDefender(t *testing.T) {
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
-	cfg.DefenderConfig.ScoreRateExceeded = 2
+	cfg.DefenderConfig.ScoreLimitExceeded = 2
 
 	err := common.Initialize(cfg)
 	assert.NoError(t, err)

+ 3 - 3
ftpd/server.go

@@ -135,13 +135,13 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
 
 // ClientConnected is called to send the very first welcome message
 func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
-	common.Connections.AddNetworkConnection()
 	ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
+	common.Connections.AddClientConnection(ipAddr)
 	if common.IsBanned(ipAddr) {
 		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
 		return "Access denied: banned client IP", common.ErrConnectionDenied
 	}
-	if !common.Connections.IsNewConnectionAllowed() {
+	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
 		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
 		return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
 	}
@@ -167,7 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
 	s.cleanTLSConnVerification(cc.ID())
 	connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
 	common.Connections.Remove(connID)
-	common.Connections.RemoveNetworkConnection()
+	common.Connections.RemoveClientConnection(utils.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
 }
 
 // AuthUser authenticates the user and selects an handling driver

+ 2 - 2
httpd/httpd_test.go

@@ -2945,7 +2945,7 @@ func TestDefenderAPI(t *testing.T) {
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
-	cfg.DefenderConfig.ScoreRateExceeded = 2
+	cfg.DefenderConfig.ScoreLimitExceeded = 2
 
 	err := common.Initialize(cfg)
 	require.NoError(t, err)
@@ -4615,7 +4615,7 @@ func TestDefender(t *testing.T) {
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
-	cfg.DefenderConfig.ScoreRateExceeded = 2
+	cfg.DefenderConfig.ScoreLimitExceeded = 2
 
 	err := common.Initialize(cfg)
 	assert.NoError(t, err)

+ 5 - 4
httpd/server.go

@@ -110,8 +110,10 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
 
 func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
 	r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
-	common.Connections.AddNetworkConnection()
-	defer common.Connections.RemoveNetworkConnection()
+
+	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+	common.Connections.AddClientConnection(ipAddr)
+	defer common.Connections.RemoveClientConnection(ipAddr)
 
 	if err := r.ParseForm(); err != nil {
 		renderClientLoginPage(w, err.Error())
@@ -128,8 +130,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
 		return
 	}
 
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
-	if !common.Connections.IsNewConnectionAllowed() {
+	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
 		logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
 		renderClientLoginPage(w, "configured connections limit reached")
 		return

+ 4 - 4
httpd/webclient.go

@@ -279,20 +279,20 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
 }
 
 func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
-	common.Connections.AddNetworkConnection()
-	defer common.Connections.RemoveNetworkConnection()
+	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+	common.Connections.AddClientConnection(ipAddr)
+	defer common.Connections.RemoveClientConnection(ipAddr)
 
 	claims, err := getTokenClaims(r)
 	if err != nil || claims.Username == "" {
 		renderClientForbiddenPage(w, r, "Invalid token claims")
 		return
 	}
-	if !common.Connections.IsNewConnectionAllowed() {
+	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
 		logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
 		renderClientForbiddenPage(w, r, "configured connections limit reached")
 		return
 	}
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
 	if common.IsBanned(ipAddr) {
 		renderClientForbiddenPage(w, r, "your IP address is banned")
 		return

+ 4 - 4
sftpd/server.go

@@ -356,7 +356,7 @@ func canAcceptConnection(ip string) bool {
 		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip)
 		return false
 	}
-	if !common.Connections.IsNewConnectionAllowed() {
+	if !common.Connections.IsNewConnectionAllowed(ip) {
 		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
 		return false
 	}
@@ -378,10 +378,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 		}
 	}()
 
-	common.Connections.AddNetworkConnection()
-	defer common.Connections.RemoveNetworkConnection()
-
 	ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
+	common.Connections.AddClientConnection(ipAddr)
+	defer common.Connections.RemoveClientConnection(ipAddr)
+
 	if !canAcceptConnection(ipAddr) {
 		conn.Close()
 		return

+ 36 - 2
sftpd/sftpd_test.go

@@ -527,7 +527,7 @@ func TestDefender(t *testing.T) {
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
-	cfg.DefenderConfig.ScoreRateExceeded = 2
+	cfg.DefenderConfig.ScoreLimitExceeded = 2
 
 	err := common.Initialize(cfg)
 	assert.NoError(t, err)
@@ -663,6 +663,9 @@ func TestOpenReadWritePerm(t *testing.T) {
 }
 
 func TestConcurrency(t *testing.T) {
+	oldValue := common.Config.MaxPerHostConnections
+	common.Config.MaxPerHostConnections = 0
+
 	usePubKey := true
 	numLogins := 50
 	u := getTestUser(usePubKey)
@@ -747,6 +750,8 @@ func TestConcurrency(t *testing.T) {
 	assert.NoError(t, err)
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
+
+	common.Config.MaxPerHostConnections = oldValue
 }
 
 func TestProxyProtocol(t *testing.T) {
@@ -2896,6 +2901,7 @@ func TestQuotaDisabledError(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+//nolint:dupl
 func TestMaxConnections(t *testing.T) {
 	oldValue := common.Config.MaxTotalConnections
 	common.Config.MaxTotalConnections = 1
@@ -2910,7 +2916,7 @@ func TestMaxConnections(t *testing.T) {
 		defer client.Close()
 		assert.NoError(t, checkBasicSFTP(client))
 		s, c, err := getSftpClient(user, usePubKey)
-		if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
+		if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
 			c.Close()
 			s.Close()
 		}
@@ -2923,6 +2929,34 @@ func TestMaxConnections(t *testing.T) {
 	common.Config.MaxTotalConnections = oldValue
 }
 
+//nolint:dupl
+func TestMaxPerHostConnections(t *testing.T) {
+	oldValue := common.Config.MaxPerHostConnections
+	common.Config.MaxPerHostConnections = 1
+
+	usePubKey := true
+	u := getTestUser(usePubKey)
+	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err)
+	conn, client, err := getSftpClient(user, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+		s, c, err := getSftpClient(user, usePubKey)
+		if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
+			c.Close()
+			s.Close()
+		}
+	}
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	common.Config.MaxPerHostConnections = oldValue
+}
+
 func TestMaxSessions(t *testing.T) {
 	usePubKey := false
 	u := getTestUser(usePubKey)

+ 2 - 1
sftpgo.json

@@ -12,6 +12,7 @@
     "startup_hook": "",
     "post_connect_hook": "",
     "max_total_connections": 0,
+    "max_per_host_connections": 20,
     "defender": {
       "enabled": false,
       "ban_time": 30,
@@ -19,7 +20,7 @@
       "threshold": 15,
       "score_invalid": 2,
       "score_valid": 1,
-      "score_rate_exceeded": 3,
+      "score_limit_exceeded": 3,
       "observation_time": 30,
       "entries_soft_limit": 100,
       "entries_hard_limit": 150,

+ 7 - 5
webdavd/server.go

@@ -144,16 +144,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
 		}
 	}()
-	common.Connections.AddNetworkConnection()
-	defer common.Connections.RemoveNetworkConnection()
 
-	if !common.Connections.IsNewConnectionAllowed() {
+	checkRemoteAddress(r)
+	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
+
+	common.Connections.AddClientConnection(ipAddr)
+	defer common.Connections.RemoveClientConnection(ipAddr)
+
+	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
 		logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
 		return
 	}
-	checkRemoteAddress(r)
-	ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
 	if common.IsBanned(ipAddr) {
 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
 		return

+ 28 - 1
webdavd/webdavd_test.go

@@ -747,7 +747,7 @@ func TestDefender(t *testing.T) {
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
-	cfg.DefenderConfig.ScoreRateExceeded = 2
+	cfg.DefenderConfig.ScoreLimitExceeded = 2
 
 	err := common.Initialize(cfg)
 	assert.NoError(t, err)
@@ -934,6 +934,33 @@ func TestMaxConnections(t *testing.T) {
 	common.Config.MaxTotalConnections = oldValue
 }
 
+func TestMaxPerHostConnections(t *testing.T) {
+	oldValue := common.Config.MaxPerHostConnections
+	common.Config.MaxPerHostConnections = 1
+
+	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+	client := getWebDavClient(user, true, nil)
+	assert.NoError(t, checkBasicFunc(client))
+	// now add a fake connection
+	addrs, err := net.LookupHost("localhost")
+	assert.NoError(t, err)
+	for _, addr := range addrs {
+		common.Connections.AddClientConnection(addr)
+	}
+	assert.Error(t, checkBasicFunc(client))
+	for _, addr := range addrs {
+		common.Connections.RemoveClientConnection(addr)
+	}
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+	assert.Len(t, common.Connections.GetStats(), 0)
+
+	common.Config.MaxPerHostConnections = oldValue
+}
+
 func TestMaxSessions(t *testing.T) {
 	u := getTestUser()
 	u.MaxSessions = 1