mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-25 00:50:31 +00:00
allow to limit the number of per-host connections
This commit is contained in:
parent
8f736da4b8
commit
8f6cdacd00
21 changed files with 356 additions and 105 deletions
51
common/clientsmap.go
Normal file
51
common/clientsmap.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/drakkan/sftpgo/logger"
|
||||
)
|
||||
|
||||
// clienstMap is a struct containing the map of the connected clients
|
||||
type clientsMap struct {
|
||||
totalConnections int32
|
||||
mu sync.RWMutex
|
||||
clients map[string]int
|
||||
}
|
||||
|
||||
func (c *clientsMap) add(source string) {
|
||||
atomic.AddInt32(&c.totalConnections, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.clients[source]++
|
||||
}
|
||||
|
||||
func (c *clientsMap) remove(source string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if val, ok := c.clients[source]; ok {
|
||||
atomic.AddInt32(&c.totalConnections, -1)
|
||||
c.clients[source]--
|
||||
if val > 1 {
|
||||
return
|
||||
}
|
||||
delete(c.clients, source)
|
||||
} else {
|
||||
logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientsMap) getTotal() int32 {
|
||||
return atomic.LoadInt32(&c.totalConnections)
|
||||
}
|
||||
|
||||
func (c *clientsMap) getTotalFrom(source string) int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.clients[source]
|
||||
}
|
59
common/clientsmap_test.go
Normal file
59
common/clientsmap_test.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientsMap(t *testing.T) {
|
||||
m := clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
ip1 := "192.168.1.1"
|
||||
ip2 := "192.168.1.2"
|
||||
m.add(ip1)
|
||||
assert.Equal(t, int32(1), m.getTotal())
|
||||
assert.Equal(t, 1, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||
|
||||
m.add(ip1)
|
||||
m.add(ip2)
|
||||
assert.Equal(t, int32(3), m.getTotal())
|
||||
assert.Equal(t, 2, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||
|
||||
m.add(ip1)
|
||||
m.add(ip1)
|
||||
m.add(ip2)
|
||||
assert.Equal(t, int32(6), m.getTotal())
|
||||
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 2, m.getTotalFrom(ip2))
|
||||
|
||||
m.remove(ip2)
|
||||
assert.Equal(t, int32(5), m.getTotal())
|
||||
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||
|
||||
m.remove("unknown")
|
||||
assert.Equal(t, int32(5), m.getTotal())
|
||||
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||
|
||||
m.remove(ip2)
|
||||
assert.Equal(t, int32(4), m.getTotal())
|
||||
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||
|
||||
m.remove(ip1)
|
||||
m.remove(ip1)
|
||||
m.remove(ip1)
|
||||
assert.Equal(t, int32(1), m.getTotal())
|
||||
assert.Equal(t, 1, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||
|
||||
m.remove(ip1)
|
||||
assert.Equal(t, int32(0), m.getTotal())
|
||||
assert.Equal(t, 0, m.getTotalFrom(ip1))
|
||||
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||
}
|
|
@ -80,6 +80,12 @@ const (
|
|||
UploadModeAtomicWithResume
|
||||
)
|
||||
|
||||
func init() {
|
||||
Connections.clients = clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// errors definitions
|
||||
var (
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
|
@ -352,6 +358,8 @@ type Configuration struct {
|
|||
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
|
||||
// Maximum number of concurrent client connections. 0 means unlimited
|
||||
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
|
||||
// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
|
||||
MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
|
||||
// Defender configuration
|
||||
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
|
||||
// Rate limiter configurations
|
||||
|
@ -524,9 +532,9 @@ func (c *SSHConnection) Close() error {
|
|||
|
||||
// ActiveConnections holds the currect active connections with the associated transfers
|
||||
type ActiveConnections struct {
|
||||
// networkConnections is the counter for the network connections, it contains
|
||||
// both authenticated and estabilished connections and the ones waiting for authentication
|
||||
networkConnections int32
|
||||
// clients contains both authenticated and estabilished connections and the ones waiting
|
||||
// for authentication
|
||||
clients clientsMap
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
sshConnections []*SSHConnection
|
||||
|
@ -693,27 +701,36 @@ func (conns *ActiveConnections) checkIdles() {
|
|||
conns.RUnlock()
|
||||
}
|
||||
|
||||
// AddNetworkConnection increments the network connections counter
|
||||
func (conns *ActiveConnections) AddNetworkConnection() {
|
||||
atomic.AddInt32(&conns.networkConnections, 1)
|
||||
// AddClientConnection stores a new client connection
|
||||
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
|
||||
conns.clients.add(ipAddr)
|
||||
}
|
||||
|
||||
// RemoveNetworkConnection decrements the network connections counter
|
||||
func (conns *ActiveConnections) RemoveNetworkConnection() {
|
||||
atomic.AddInt32(&conns.networkConnections, -1)
|
||||
// RemoveClientConnection removes a disconnected client from the tracked ones
|
||||
func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
|
||||
conns.clients.remove(ipAddr)
|
||||
}
|
||||
|
||||
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
|
||||
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
||||
if Config.MaxTotalConnections == 0 {
|
||||
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
|
||||
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
num := atomic.LoadInt32(&conns.networkConnections)
|
||||
if num > int32(Config.MaxTotalConnections) {
|
||||
logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections)
|
||||
if Config.MaxPerHostConnections > 0 {
|
||||
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
|
||||
logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
|
||||
AddDefenderEvent(ipAddr, HostEventLimitExceeded)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if Config.MaxTotalConnections > 0 {
|
||||
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
|
||||
logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
|
||||
return false
|
||||
}
|
||||
|
||||
// on a single SFTP connection we could have multiple SFTP channels or commands
|
||||
// so we check the estabilished connections too
|
||||
|
||||
|
@ -721,6 +738,9 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
|||
defer conns.RUnlock()
|
||||
|
||||
return len(conns.connections) < Config.MaxTotalConnections
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetStats returns stats for active connections
|
||||
|
|
|
@ -228,32 +228,61 @@ func TestRateLimitersIntegration(t *testing.T) {
|
|||
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := Config.MaxTotalConnections
|
||||
Config.MaxTotalConnections = 1
|
||||
perHost := Config.MaxPerHostConnections
|
||||
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
Config.MaxPerHostConnections = 0
|
||||
|
||||
ipAddr := "192.168.7.8"
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
Config.MaxTotalConnections = 1
|
||||
Config.MaxPerHostConnections = perHost
|
||||
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
|
||||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
Connections.Add(fakeConn)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
assert.False(t, Connections.IsNewConnectionAllowed())
|
||||
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
res := Connections.Close(fakeConn.GetID())
|
||||
assert.True(t, res)
|
||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
|
||||
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
Connections.AddNetworkConnection()
|
||||
Connections.AddNetworkConnection()
|
||||
assert.False(t, Connections.IsNewConnectionAllowed())
|
||||
Connections.RemoveNetworkConnection()
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
Connections.RemoveNetworkConnection()
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
Connections.AddClientConnection(ipAddr)
|
||||
Connections.AddClientConnection(ipAddr)
|
||||
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxConnectionPerHost(t *testing.T) {
|
||||
oldValue := Config.MaxPerHostConnections
|
||||
|
||||
Config.MaxPerHostConnections = 2
|
||||
|
||||
ipAddr := "192.168.9.9"
|
||||
Connections.AddClientConnection(ipAddr)
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
Connections.AddClientConnection(ipAddr)
|
||||
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
Connections.AddClientConnection(ipAddr)
|
||||
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestIdleConnections(t *testing.T) {
|
||||
configCopy := Config
|
||||
|
||||
|
@ -340,7 +369,7 @@ func TestCloseConnection(t *testing.T) {
|
|||
fakeConn := &fakeConnection{
|
||||
BaseConnection: c,
|
||||
}
|
||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
||||
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
|
||||
Connections.Add(fakeConn)
|
||||
assert.Len(t, Connections.GetStats(), 1)
|
||||
res := Connections.Close(fakeConn.GetID())
|
||||
|
|
|
@ -23,7 +23,7 @@ const (
|
|||
HostEventLoginFailed HostEvent = iota
|
||||
HostEventUserNotFound
|
||||
HostEventNoLoginTried
|
||||
HostEventRateExceeded
|
||||
HostEventLimitExceeded
|
||||
)
|
||||
|
||||
// Defender defines the interface that a defender must implements
|
||||
|
@ -51,8 +51,9 @@ type DefenderConfig struct {
|
|||
ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
|
||||
// Score for valid login attempts, eg. user accounts that exist
|
||||
ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
|
||||
// Score for rate exceeded events, generated from the rate limiters
|
||||
ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
|
||||
// Score for limit exceeded events, generated from the rate limiters or for max connections
|
||||
// per-host exceeded
|
||||
ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
|
||||
// Defines the time window, in minutes, for tracking client errors.
|
||||
// A host is banned if it has exceeded the defined threshold during
|
||||
// the last observation time minutes
|
||||
|
@ -126,8 +127,8 @@ func (c *DefenderConfig) validate() error {
|
|||
if c.ScoreValid >= c.Threshold {
|
||||
return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
|
||||
}
|
||||
if c.ScoreRateExceeded >= c.Threshold {
|
||||
return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
|
||||
if c.ScoreLimitExceeded >= c.Threshold {
|
||||
return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
|
||||
}
|
||||
if c.BanTime <= 0 {
|
||||
return fmt.Errorf("invalid ban_time %v", c.BanTime)
|
||||
|
@ -254,8 +255,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
|
|||
switch event {
|
||||
case HostEventLoginFailed:
|
||||
score = d.config.ScoreValid
|
||||
case HostEventRateExceeded:
|
||||
score = d.config.ScoreRateExceeded
|
||||
case HostEventLimitExceeded:
|
||||
score = d.config.ScoreLimitExceeded
|
||||
case HostEventUserNotFound, HostEventNoLoginTried:
|
||||
score = d.config.ScoreInvalid
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func TestBasicDefender(t *testing.T) {
|
|||
Threshold: 5,
|
||||
ScoreInvalid: 2,
|
||||
ScoreValid: 1,
|
||||
ScoreRateExceeded: 3,
|
||||
ScoreLimitExceeded: 3,
|
||||
ObservationTime: 15,
|
||||
EntriesSoftLimit: 1,
|
||||
EntriesHardLimit: 2,
|
||||
|
@ -75,7 +75,7 @@ func TestBasicDefender(t *testing.T) {
|
|||
|
||||
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
|
||||
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
|
||||
defender.AddEvent("172.16.1.3", HostEventRateExceeded)
|
||||
defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
|
||||
assert.Equal(t, 0, defender.countHosts())
|
||||
|
||||
testIP := "12.34.56.78"
|
||||
|
@ -84,7 +84,7 @@ func TestBasicDefender(t *testing.T) {
|
|||
assert.Equal(t, 0, defender.countBanned())
|
||||
assert.Equal(t, 1, defender.GetScore(testIP))
|
||||
assert.Nil(t, defender.GetBanTime(testIP))
|
||||
defender.AddEvent(testIP, HostEventRateExceeded)
|
||||
defender.AddEvent(testIP, HostEventLimitExceeded)
|
||||
assert.Equal(t, 1, defender.countHosts())
|
||||
assert.Equal(t, 0, defender.countBanned())
|
||||
assert.Equal(t, 4, defender.GetScore(testIP))
|
||||
|
@ -317,11 +317,11 @@ func TestDefenderConfig(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
|
||||
c.ScoreInvalid = 2
|
||||
c.ScoreRateExceeded = 10
|
||||
c.ScoreLimitExceeded = 10
|
||||
err = c.validate()
|
||||
require.Error(t, err)
|
||||
|
||||
c.ScoreRateExceeded = 2
|
||||
c.ScoreLimitExceeded = 2
|
||||
c.ScoreValid = 10
|
||||
err = c.validate()
|
||||
require.Error(t, err)
|
||||
|
|
|
@ -149,7 +149,7 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
|
|||
if delay > rl.maxDelay {
|
||||
res.Cancel()
|
||||
if rl.generateDefenderEvents && rl.globalBucket == nil {
|
||||
AddDefenderEvent(source, HostEventRateExceeded)
|
||||
AddDefenderEvent(source, HostEventLimitExceeded)
|
||||
}
|
||||
return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
|
||||
}
|
||||
|
|
|
@ -116,6 +116,7 @@ func Init() {
|
|||
ProxyAllowed: []string{},
|
||||
PostConnectHook: "",
|
||||
MaxTotalConnections: 0,
|
||||
MaxPerHostConnections: 20,
|
||||
DefenderConfig: common.DefenderConfig{
|
||||
Enabled: false,
|
||||
BanTime: 30,
|
||||
|
@ -123,7 +124,7 @@ func Init() {
|
|||
Threshold: 15,
|
||||
ScoreInvalid: 2,
|
||||
ScoreValid: 1,
|
||||
ScoreRateExceeded: 3,
|
||||
ScoreLimitExceeded: 3,
|
||||
ObservationTime: 30,
|
||||
EntriesSoftLimit: 100,
|
||||
EntriesHardLimit: 150,
|
||||
|
@ -873,13 +874,14 @@ func setViperDefaults() {
|
|||
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
|
||||
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
|
||||
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
|
||||
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
|
||||
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
|
||||
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
|
||||
viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
|
||||
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
|
||||
viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
|
||||
viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
|
||||
viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded)
|
||||
viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
|
||||
viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
|
||||
viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
|
||||
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)
|
||||
|
|
|
@ -8,7 +8,7 @@ You can configure a score for each event type:
|
|||
|
||||
- `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
|
||||
- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
|
||||
- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`.
|
||||
- `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
|
||||
|
||||
And then you can configure:
|
||||
|
||||
|
|
|
@ -64,7 +64,8 @@ The configuration file contains the following sections:
|
|||
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
|
||||
- `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable
|
||||
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
|
||||
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
|
||||
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0.
|
||||
- `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
|
||||
- `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
|
||||
- `enabled`, boolean. Default `false`.
|
||||
- `ban_time`, integer. Ban time in minutes.
|
||||
|
@ -72,7 +73,7 @@ The configuration file contains the following sections:
|
|||
- `threshold`, integer. Threshold value for banning a client.
|
||||
- `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
|
||||
- `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
|
||||
- `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits.
|
||||
- `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
|
||||
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
|
||||
- `entries_soft_limit`, integer.
|
||||
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.
|
||||
|
|
|
@ -18,7 +18,7 @@ The supported protocols are:
|
|||
You can also define two types of rate limiters:
|
||||
|
||||
- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
|
||||
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
|
||||
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
|
||||
|
||||
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
|
||||
|
||||
|
|
|
@ -779,13 +779,36 @@ func TestMaxConnections(t *testing.T) {
|
|||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxPerHostConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 1
|
||||
|
||||
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
client, err := getFTPClient(user, true, nil)
|
||||
if assert.NoError(t, err) {
|
||||
err = checkBasicFTP(client)
|
||||
assert.NoError(t, err)
|
||||
_, err = getFTPClient(user, false, nil)
|
||||
assert.Error(t, err)
|
||||
err = client.Quit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
oldConfig := config.GetCommonConfig()
|
||||
|
||||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 5
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 3
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 3
|
||||
cfg.RateLimitersConfig = []common.RateLimiterConfig{
|
||||
{
|
||||
Average: 1,
|
||||
|
@ -843,7 +866,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -135,13 +135,13 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
|
|||
|
||||
// ClientConnected is called to send the very first welcome message
|
||||
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
||||
common.Connections.AddNetworkConnection()
|
||||
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
|
||||
common.Connections.AddClientConnection(ipAddr)
|
||||
if common.IsBanned(ipAddr) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
|
||||
return "Access denied: banned client IP", common.ErrConnectionDenied
|
||||
}
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
|
||||
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
|
|||
s.cleanTLSConnVerification(cc.ID())
|
||||
connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
|
||||
common.Connections.Remove(connID)
|
||||
common.Connections.RemoveNetworkConnection()
|
||||
common.Connections.RemoveClientConnection(utils.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
|
||||
}
|
||||
|
||||
// AuthUser authenticates the user and selects an handling driver
|
||||
|
|
|
@ -2945,7 +2945,7 @@ func TestDefenderAPI(t *testing.T) {
|
|||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
require.NoError(t, err)
|
||||
|
@ -4615,7 +4615,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -110,8 +110,10 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
|
|||
|
||||
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
|
||||
common.Connections.AddNetworkConnection()
|
||||
defer common.Connections.RemoveNetworkConnection()
|
||||
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
common.Connections.AddClientConnection(ipAddr)
|
||||
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
renderClientLoginPage(w, err.Error())
|
||||
|
@ -128,8 +130,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
|
|||
return
|
||||
}
|
||||
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
||||
renderClientLoginPage(w, "configured connections limit reached")
|
||||
return
|
||||
|
|
|
@ -279,20 +279,20 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
|
||||
common.Connections.AddNetworkConnection()
|
||||
defer common.Connections.RemoveNetworkConnection()
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
common.Connections.AddClientConnection(ipAddr)
|
||||
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
claims, err := getTokenClaims(r)
|
||||
if err != nil || claims.Username == "" {
|
||||
renderClientForbiddenPage(w, r, "Invalid token claims")
|
||||
return
|
||||
}
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
||||
renderClientForbiddenPage(w, r, "configured connections limit reached")
|
||||
return
|
||||
}
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
if common.IsBanned(ipAddr) {
|
||||
renderClientForbiddenPage(w, r, "your IP address is banned")
|
||||
return
|
||||
|
|
|
@ -356,7 +356,7 @@ func canAcceptConnection(ip string) bool {
|
|||
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip)
|
||||
return false
|
||||
}
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
if !common.Connections.IsNewConnectionAllowed(ip) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
|
||||
return false
|
||||
}
|
||||
|
@ -378,10 +378,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
}
|
||||
}()
|
||||
|
||||
common.Connections.AddNetworkConnection()
|
||||
defer common.Connections.RemoveNetworkConnection()
|
||||
|
||||
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||
common.Connections.AddClientConnection(ipAddr)
|
||||
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
if !canAcceptConnection(ipAddr) {
|
||||
conn.Close()
|
||||
return
|
||||
|
|
|
@ -527,7 +527,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
assert.NoError(t, err)
|
||||
|
@ -663,6 +663,9 @@ func TestOpenReadWritePerm(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 0
|
||||
|
||||
usePubKey := true
|
||||
numLogins := 50
|
||||
u := getTestUser(usePubKey)
|
||||
|
@ -747,6 +750,8 @@ func TestConcurrency(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestProxyProtocol(t *testing.T) {
|
||||
|
@ -2896,6 +2901,7 @@ func TestQuotaDisabledError(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
//nolint:dupl
|
||||
func TestMaxConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxTotalConnections
|
||||
common.Config.MaxTotalConnections = 1
|
||||
|
@ -2910,7 +2916,7 @@ func TestMaxConnections(t *testing.T) {
|
|||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
s, c, err := getSftpClient(user, usePubKey)
|
||||
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
|
||||
if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
|
||||
c.Close()
|
||||
s.Close()
|
||||
}
|
||||
|
@ -2923,6 +2929,34 @@ func TestMaxConnections(t *testing.T) {
|
|||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
//nolint:dupl
|
||||
func TestMaxPerHostConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 1
|
||||
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
conn, client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
s, c, err := getSftpClient(user, usePubKey)
|
||||
if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
|
||||
c.Close()
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
usePubKey := false
|
||||
u := getTestUser(usePubKey)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
"startup_hook": "",
|
||||
"post_connect_hook": "",
|
||||
"max_total_connections": 0,
|
||||
"max_per_host_connections": 20,
|
||||
"defender": {
|
||||
"enabled": false,
|
||||
"ban_time": 30,
|
||||
|
@ -19,7 +20,7 @@
|
|||
"threshold": 15,
|
||||
"score_invalid": 2,
|
||||
"score_valid": 1,
|
||||
"score_rate_exceeded": 3,
|
||||
"score_limit_exceeded": 3,
|
||||
"observation_time": 30,
|
||||
"entries_soft_limit": 100,
|
||||
"entries_hard_limit": 150,
|
||||
|
|
|
@ -144,16 +144,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
common.Connections.AddNetworkConnection()
|
||||
defer common.Connections.RemoveNetworkConnection()
|
||||
|
||||
if !common.Connections.IsNewConnectionAllowed() {
|
||||
checkRemoteAddress(r)
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
|
||||
common.Connections.AddClientConnection(ipAddr)
|
||||
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||
|
||||
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||
logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
|
||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
checkRemoteAddress(r)
|
||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||
if common.IsBanned(ipAddr) {
|
||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
||||
return
|
||||
|
|
|
@ -747,7 +747,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
assert.NoError(t, err)
|
||||
|
@ -934,6 +934,33 @@ func TestMaxConnections(t *testing.T) {
|
|||
common.Config.MaxTotalConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxPerHostConnections(t *testing.T) {
|
||||
oldValue := common.Config.MaxPerHostConnections
|
||||
common.Config.MaxPerHostConnections = 1
|
||||
|
||||
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
|
||||
assert.NoError(t, err)
|
||||
client := getWebDavClient(user, true, nil)
|
||||
assert.NoError(t, checkBasicFunc(client))
|
||||
// now add a fake connection
|
||||
addrs, err := net.LookupHost("localhost")
|
||||
assert.NoError(t, err)
|
||||
for _, addr := range addrs {
|
||||
common.Connections.AddClientConnection(addr)
|
||||
}
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
for _, addr := range addrs {
|
||||
common.Connections.RemoveClientConnection(addr)
|
||||
}
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
|
||||
common.Config.MaxPerHostConnections = oldValue
|
||||
}
|
||||
|
||||
func TestMaxSessions(t *testing.T) {
|
||||
u := getTestUser()
|
||||
u.MaxSessions = 1
|
||||
|
|
Loading…
Reference in a new issue