allow to limit the number of per-host connections

This commit is contained in:
Nicola Murino 2021-05-08 19:45:21 +02:00
parent 8f736da4b8
commit 8f6cdacd00
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
21 changed files with 356 additions and 105 deletions

51
common/clientsmap.go Normal file
View file

@ -0,0 +1,51 @@
package common
import (
"sync"
"sync/atomic"
"github.com/drakkan/sftpgo/logger"
)
// clienstMap is a struct containing the map of the connected clients
type clientsMap struct {
totalConnections int32
mu sync.RWMutex
clients map[string]int
}
func (c *clientsMap) add(source string) {
atomic.AddInt32(&c.totalConnections, 1)
c.mu.Lock()
defer c.mu.Unlock()
c.clients[source]++
}
func (c *clientsMap) remove(source string) {
c.mu.Lock()
defer c.mu.Unlock()
if val, ok := c.clients[source]; ok {
atomic.AddInt32(&c.totalConnections, -1)
c.clients[source]--
if val > 1 {
return
}
delete(c.clients, source)
} else {
logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
}
}
func (c *clientsMap) getTotal() int32 {
return atomic.LoadInt32(&c.totalConnections)
}
func (c *clientsMap) getTotalFrom(source string) int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.clients[source]
}

59
common/clientsmap_test.go Normal file
View file

@ -0,0 +1,59 @@
package common
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestClientsMap(t *testing.T) {
m := clientsMap{
clients: make(map[string]int),
}
ip1 := "192.168.1.1"
ip2 := "192.168.1.2"
m.add(ip1)
assert.Equal(t, int32(1), m.getTotal())
assert.Equal(t, 1, m.getTotalFrom(ip1))
assert.Equal(t, 0, m.getTotalFrom(ip2))
m.add(ip1)
m.add(ip2)
assert.Equal(t, int32(3), m.getTotal())
assert.Equal(t, 2, m.getTotalFrom(ip1))
assert.Equal(t, 1, m.getTotalFrom(ip2))
m.add(ip1)
m.add(ip1)
m.add(ip2)
assert.Equal(t, int32(6), m.getTotal())
assert.Equal(t, 4, m.getTotalFrom(ip1))
assert.Equal(t, 2, m.getTotalFrom(ip2))
m.remove(ip2)
assert.Equal(t, int32(5), m.getTotal())
assert.Equal(t, 4, m.getTotalFrom(ip1))
assert.Equal(t, 1, m.getTotalFrom(ip2))
m.remove("unknown")
assert.Equal(t, int32(5), m.getTotal())
assert.Equal(t, 4, m.getTotalFrom(ip1))
assert.Equal(t, 1, m.getTotalFrom(ip2))
m.remove(ip2)
assert.Equal(t, int32(4), m.getTotal())
assert.Equal(t, 4, m.getTotalFrom(ip1))
assert.Equal(t, 0, m.getTotalFrom(ip2))
m.remove(ip1)
m.remove(ip1)
m.remove(ip1)
assert.Equal(t, int32(1), m.getTotal())
assert.Equal(t, 1, m.getTotalFrom(ip1))
assert.Equal(t, 0, m.getTotalFrom(ip2))
m.remove(ip1)
assert.Equal(t, int32(0), m.getTotal())
assert.Equal(t, 0, m.getTotalFrom(ip1))
assert.Equal(t, 0, m.getTotalFrom(ip2))
}

View file

@ -80,6 +80,12 @@ const (
UploadModeAtomicWithResume
)
func init() {
Connections.clients = clientsMap{
clients: make(map[string]int),
}
}
// errors definitions
var (
ErrPermissionDenied = errors.New("permission denied")
@ -352,6 +358,8 @@ type Configuration struct {
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
// Maximum number of concurrent client connections. 0 means unlimited
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
// Defender configuration
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
// Rate limiter configurations
@ -524,9 +532,9 @@ func (c *SSHConnection) Close() error {
// ActiveConnections holds the currect active connections with the associated transfers
type ActiveConnections struct {
// networkConnections is the counter for the network connections, it contains
// both authenticated and estabilished connections and the ones waiting for authentication
networkConnections int32
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
@ -693,27 +701,36 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock()
}
// AddNetworkConnection increments the network connections counter
func (conns *ActiveConnections) AddNetworkConnection() {
atomic.AddInt32(&conns.networkConnections, 1)
// AddClientConnection stores a new client connection
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
conns.clients.add(ipAddr)
}
// RemoveNetworkConnection decrements the network connections counter
func (conns *ActiveConnections) RemoveNetworkConnection() {
atomic.AddInt32(&conns.networkConnections, -1)
// RemoveClientConnection removes a disconnected client from the tracked ones
func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
conns.clients.remove(ipAddr)
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
if Config.MaxTotalConnections == 0 {
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
return true
}
num := atomic.LoadInt32(&conns.networkConnections)
if num > int32(Config.MaxTotalConnections) {
logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections)
if Config.MaxPerHostConnections > 0 {
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
AddDefenderEvent(ipAddr, HostEventLimitExceeded)
return false
}
}
if Config.MaxTotalConnections > 0 {
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
return false
}
// on a single SFTP connection we could have multiple SFTP channels or commands
// so we check the estabilished connections too
@ -721,6 +738,9 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
defer conns.RUnlock()
return len(conns.connections) < Config.MaxTotalConnections
}
return true
}
// GetStats returns stats for active connections

View file

@ -228,32 +228,61 @@ func TestRateLimitersIntegration(t *testing.T) {
func TestMaxConnections(t *testing.T) {
oldValue := Config.MaxTotalConnections
Config.MaxTotalConnections = 1
perHost := Config.MaxPerHostConnections
assert.True(t, Connections.IsNewConnectionAllowed())
Config.MaxPerHostConnections = 0
ipAddr := "192.168.7.8"
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Config.MaxTotalConnections = 1
Config.MaxPerHostConnections = perHost
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
}
Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1)
assert.False(t, Connections.IsNewConnectionAllowed())
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
res := Connections.Close(fakeConn.GetID())
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.AddNetworkConnection()
Connections.AddNetworkConnection()
assert.False(t, Connections.IsNewConnectionAllowed())
Connections.RemoveNetworkConnection()
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.RemoveNetworkConnection()
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
Connections.AddClientConnection(ipAddr)
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr)
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr)
Config.MaxTotalConnections = oldValue
}
func TestMaxConnectionPerHost(t *testing.T) {
oldValue := Config.MaxPerHostConnections
Config.MaxPerHostConnections = 2
ipAddr := "192.168.9.9"
Connections.AddClientConnection(ipAddr)
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
Config.MaxPerHostConnections = oldValue
}
func TestIdleConnections(t *testing.T) {
configCopy := Config
@ -340,7 +369,7 @@ func TestCloseConnection(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
assert.True(t, Connections.IsNewConnectionAllowed())
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1)
res := Connections.Close(fakeConn.GetID())

View file

@ -23,7 +23,7 @@ const (
HostEventLoginFailed HostEvent = iota
HostEventUserNotFound
HostEventNoLoginTried
HostEventRateExceeded
HostEventLimitExceeded
)
// Defender defines the interface that a defender must implements
@ -51,8 +51,9 @@ type DefenderConfig struct {
ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
// Score for valid login attempts, eg. user accounts that exist
ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
// Score for rate exceeded events, generated from the rate limiters
ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
// Score for limit exceeded events, generated from the rate limiters or for max connections
// per-host exceeded
ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
// Defines the time window, in minutes, for tracking client errors.
// A host is banned if it has exceeded the defined threshold during
// the last observation time minutes
@ -126,8 +127,8 @@ func (c *DefenderConfig) validate() error {
if c.ScoreValid >= c.Threshold {
return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
}
if c.ScoreRateExceeded >= c.Threshold {
return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
if c.ScoreLimitExceeded >= c.Threshold {
return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
}
if c.BanTime <= 0 {
return fmt.Errorf("invalid ban_time %v", c.BanTime)
@ -254,8 +255,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
switch event {
case HostEventLoginFailed:
score = d.config.ScoreValid
case HostEventRateExceeded:
score = d.config.ScoreRateExceeded
case HostEventLimitExceeded:
score = d.config.ScoreLimitExceeded
case HostEventUserNotFound, HostEventNoLoginTried:
score = d.config.ScoreInvalid
}

View file

@ -47,7 +47,7 @@ func TestBasicDefender(t *testing.T) {
Threshold: 5,
ScoreInvalid: 2,
ScoreValid: 1,
ScoreRateExceeded: 3,
ScoreLimitExceeded: 3,
ObservationTime: 15,
EntriesSoftLimit: 1,
EntriesHardLimit: 2,
@ -75,7 +75,7 @@ func TestBasicDefender(t *testing.T) {
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
defender.AddEvent("172.16.1.3", HostEventRateExceeded)
defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
assert.Equal(t, 0, defender.countHosts())
testIP := "12.34.56.78"
@ -84,7 +84,7 @@ func TestBasicDefender(t *testing.T) {
assert.Equal(t, 0, defender.countBanned())
assert.Equal(t, 1, defender.GetScore(testIP))
assert.Nil(t, defender.GetBanTime(testIP))
defender.AddEvent(testIP, HostEventRateExceeded)
defender.AddEvent(testIP, HostEventLimitExceeded)
assert.Equal(t, 1, defender.countHosts())
assert.Equal(t, 0, defender.countBanned())
assert.Equal(t, 4, defender.GetScore(testIP))
@ -317,11 +317,11 @@ func TestDefenderConfig(t *testing.T) {
require.Error(t, err)
c.ScoreInvalid = 2
c.ScoreRateExceeded = 10
c.ScoreLimitExceeded = 10
err = c.validate()
require.Error(t, err)
c.ScoreRateExceeded = 2
c.ScoreLimitExceeded = 2
c.ScoreValid = 10
err = c.validate()
require.Error(t, err)

View file

@ -149,7 +149,7 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
if delay > rl.maxDelay {
res.Cancel()
if rl.generateDefenderEvents && rl.globalBucket == nil {
AddDefenderEvent(source, HostEventRateExceeded)
AddDefenderEvent(source, HostEventLimitExceeded)
}
return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
}

View file

@ -116,6 +116,7 @@ func Init() {
ProxyAllowed: []string{},
PostConnectHook: "",
MaxTotalConnections: 0,
MaxPerHostConnections: 20,
DefenderConfig: common.DefenderConfig{
Enabled: false,
BanTime: 30,
@ -123,7 +124,7 @@ func Init() {
Threshold: 15,
ScoreInvalid: 2,
ScoreValid: 1,
ScoreRateExceeded: 3,
ScoreLimitExceeded: 3,
ObservationTime: 30,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
@ -873,13 +874,14 @@ func setViperDefaults() {
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded)
viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)

View file

@ -8,7 +8,7 @@ You can configure a score for each event type:
- `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`.
- `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
And then you can configure:

View file

@ -64,7 +64,8 @@ The configuration file contains the following sections:
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
- `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0.
- `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
- `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
- `enabled`, boolean. Default `false`.
- `ban_time`, integer. Ban time in minutes.
@ -72,7 +73,7 @@ The configuration file contains the following sections:
- `threshold`, integer. Threshold value for banning a client.
- `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
- `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
- `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits.
- `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
- `entries_soft_limit`, integer.
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.

View file

@ -18,7 +18,7 @@ The supported protocols are:
You can also define two types of rate limiters:
- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.

View file

@ -779,13 +779,36 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = oldValue
}
func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client, err := getFTPClient(user, true, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
_, err = getFTPClient(user, false, nil)
assert.Error(t, err)
err = client.Quit()
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxPerHostConnections = oldValue
}
func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 5
cfg.DefenderConfig.ScoreRateExceeded = 3
cfg.DefenderConfig.ScoreLimitExceeded = 3
cfg.RateLimitersConfig = []common.RateLimiterConfig{
{
Average: 1,
@ -843,7 +866,7 @@ func TestDefender(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
cfg.DefenderConfig.ScoreLimitExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)

View file

@ -135,13 +135,13 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
// ClientConnected is called to send the very first welcome message
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
common.Connections.AddNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
common.Connections.AddClientConnection(ipAddr)
if common.IsBanned(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
return "Access denied: banned client IP", common.ErrConnectionDenied
}
if !common.Connections.IsNewConnectionAllowed() {
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
}
@ -167,7 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
s.cleanTLSConnVerification(cc.ID())
connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
common.Connections.Remove(connID)
common.Connections.RemoveNetworkConnection()
common.Connections.RemoveClientConnection(utils.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
}
// AuthUser authenticates the user and selects an handling driver

View file

@ -2945,7 +2945,7 @@ func TestDefenderAPI(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
cfg.DefenderConfig.ScoreLimitExceeded = 2
err := common.Initialize(cfg)
require.NoError(t, err)
@ -4615,7 +4615,7 @@ func TestDefender(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
cfg.DefenderConfig.ScoreLimitExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)

View file

@ -110,8 +110,10 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
if err := r.ParseForm(); err != nil {
renderClientLoginPage(w, err.Error())
@ -128,8 +130,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
return
}
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
if !common.Connections.IsNewConnectionAllowed() {
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
renderClientLoginPage(w, "configured connections limit reached")
return

View file

@ -279,20 +279,20 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
}
func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
claims, err := getTokenClaims(r)
if err != nil || claims.Username == "" {
renderClientForbiddenPage(w, r, "Invalid token claims")
return
}
if !common.Connections.IsNewConnectionAllowed() {
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
renderClientForbiddenPage(w, r, "configured connections limit reached")
return
}
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
if common.IsBanned(ipAddr) {
renderClientForbiddenPage(w, r, "your IP address is banned")
return

View file

@ -356,7 +356,7 @@ func canAcceptConnection(ip string) bool {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip)
return false
}
if !common.Connections.IsNewConnectionAllowed() {
if !common.Connections.IsNewConnectionAllowed(ip) {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
return false
}
@ -378,10 +378,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
}
}()
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
if !canAcceptConnection(ipAddr) {
conn.Close()
return

View file

@ -527,7 +527,7 @@ func TestDefender(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
cfg.DefenderConfig.ScoreLimitExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)
@ -663,6 +663,9 @@ func TestOpenReadWritePerm(t *testing.T) {
}
func TestConcurrency(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 0
usePubKey := true
numLogins := 50
u := getTestUser(usePubKey)
@ -747,6 +750,8 @@ func TestConcurrency(t *testing.T) {
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxPerHostConnections = oldValue
}
func TestProxyProtocol(t *testing.T) {
@ -2896,6 +2901,7 @@ func TestQuotaDisabledError(t *testing.T) {
assert.NoError(t, err)
}
//nolint:dupl
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
@ -2910,7 +2916,7 @@ func TestMaxConnections(t *testing.T) {
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
s, c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
c.Close()
s.Close()
}
@ -2923,6 +2929,34 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = oldValue
}
//nolint:dupl
func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
s, c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
c.Close()
s.Close()
}
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxPerHostConnections = oldValue
}
func TestMaxSessions(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)

View file

@ -12,6 +12,7 @@
"startup_hook": "",
"post_connect_hook": "",
"max_total_connections": 0,
"max_per_host_connections": 20,
"defender": {
"enabled": false,
"ban_time": 30,
@ -19,7 +20,7 @@
"threshold": 15,
"score_invalid": 2,
"score_valid": 1,
"score_rate_exceeded": 3,
"score_limit_exceeded": 3,
"observation_time": 30,
"entries_soft_limit": 100,
"entries_hard_limit": 150,

View file

@ -144,16 +144,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
}
}()
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
if !common.Connections.IsNewConnectionAllowed() {
checkRemoteAddress(r)
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
return
}
checkRemoteAddress(r)
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
if common.IsBanned(ipAddr) {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return

View file

@ -747,7 +747,7 @@ func TestDefender(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
cfg.DefenderConfig.ScoreLimitExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)
@ -934,6 +934,33 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = oldValue
}
func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(user, true, nil)
assert.NoError(t, checkBasicFunc(client))
// now add a fake connection
addrs, err := net.LookupHost("localhost")
assert.NoError(t, err)
for _, addr := range addrs {
common.Connections.AddClientConnection(addr)
}
assert.Error(t, checkBasicFunc(client))
for _, addr := range addrs {
common.Connections.RemoveClientConnection(addr)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
common.Config.MaxPerHostConnections = oldValue
}
func TestMaxSessions(t *testing.T) {
u := getTestUser()
u.MaxSessions = 1