mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
allow to limit the number of per-host connections
This commit is contained in:
parent
8f736da4b8
commit
8f6cdacd00
21 changed files with 356 additions and 105 deletions
51
common/clientsmap.go
Normal file
51
common/clientsmap.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/drakkan/sftpgo/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// clienstMap is a struct containing the map of the connected clients
|
||||||
|
type clientsMap struct {
|
||||||
|
totalConnections int32
|
||||||
|
mu sync.RWMutex
|
||||||
|
clients map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientsMap) add(source string) {
|
||||||
|
atomic.AddInt32(&c.totalConnections, 1)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.clients[source]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientsMap) remove(source string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if val, ok := c.clients[source]; ok {
|
||||||
|
atomic.AddInt32(&c.totalConnections, -1)
|
||||||
|
c.clients[source]--
|
||||||
|
if val > 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.clients, source)
|
||||||
|
} else {
|
||||||
|
logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientsMap) getTotal() int32 {
|
||||||
|
return atomic.LoadInt32(&c.totalConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientsMap) getTotalFrom(source string) int {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
return c.clients[source]
|
||||||
|
}
|
59
common/clientsmap_test.go
Normal file
59
common/clientsmap_test.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientsMap(t *testing.T) {
|
||||||
|
m := clientsMap{
|
||||||
|
clients: make(map[string]int),
|
||||||
|
}
|
||||||
|
ip1 := "192.168.1.1"
|
||||||
|
ip2 := "192.168.1.2"
|
||||||
|
m.add(ip1)
|
||||||
|
assert.Equal(t, int32(1), m.getTotal())
|
||||||
|
assert.Equal(t, 1, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.add(ip1)
|
||||||
|
m.add(ip2)
|
||||||
|
assert.Equal(t, int32(3), m.getTotal())
|
||||||
|
assert.Equal(t, 2, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.add(ip1)
|
||||||
|
m.add(ip1)
|
||||||
|
m.add(ip2)
|
||||||
|
assert.Equal(t, int32(6), m.getTotal())
|
||||||
|
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 2, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.remove(ip2)
|
||||||
|
assert.Equal(t, int32(5), m.getTotal())
|
||||||
|
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.remove("unknown")
|
||||||
|
assert.Equal(t, int32(5), m.getTotal())
|
||||||
|
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 1, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.remove(ip2)
|
||||||
|
assert.Equal(t, int32(4), m.getTotal())
|
||||||
|
assert.Equal(t, 4, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.remove(ip1)
|
||||||
|
m.remove(ip1)
|
||||||
|
m.remove(ip1)
|
||||||
|
assert.Equal(t, int32(1), m.getTotal())
|
||||||
|
assert.Equal(t, 1, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||||
|
|
||||||
|
m.remove(ip1)
|
||||||
|
assert.Equal(t, int32(0), m.getTotal())
|
||||||
|
assert.Equal(t, 0, m.getTotalFrom(ip1))
|
||||||
|
assert.Equal(t, 0, m.getTotalFrom(ip2))
|
||||||
|
}
|
|
@ -80,6 +80,12 @@ const (
|
||||||
UploadModeAtomicWithResume
|
UploadModeAtomicWithResume
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Connections.clients = clientsMap{
|
||||||
|
clients: make(map[string]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// errors definitions
|
// errors definitions
|
||||||
var (
|
var (
|
||||||
ErrPermissionDenied = errors.New("permission denied")
|
ErrPermissionDenied = errors.New("permission denied")
|
||||||
|
@ -352,6 +358,8 @@ type Configuration struct {
|
||||||
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
|
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
|
||||||
// Maximum number of concurrent client connections. 0 means unlimited
|
// Maximum number of concurrent client connections. 0 means unlimited
|
||||||
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
|
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
|
||||||
|
// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
|
||||||
|
MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
|
||||||
// Defender configuration
|
// Defender configuration
|
||||||
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
|
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
|
||||||
// Rate limiter configurations
|
// Rate limiter configurations
|
||||||
|
@ -524,9 +532,9 @@ func (c *SSHConnection) Close() error {
|
||||||
|
|
||||||
// ActiveConnections holds the currect active connections with the associated transfers
|
// ActiveConnections holds the currect active connections with the associated transfers
|
||||||
type ActiveConnections struct {
|
type ActiveConnections struct {
|
||||||
// networkConnections is the counter for the network connections, it contains
|
// clients contains both authenticated and estabilished connections and the ones waiting
|
||||||
// both authenticated and estabilished connections and the ones waiting for authentication
|
// for authentication
|
||||||
networkConnections int32
|
clients clientsMap
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
connections []ActiveConnection
|
connections []ActiveConnection
|
||||||
sshConnections []*SSHConnection
|
sshConnections []*SSHConnection
|
||||||
|
@ -693,27 +701,36 @@ func (conns *ActiveConnections) checkIdles() {
|
||||||
conns.RUnlock()
|
conns.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddNetworkConnection increments the network connections counter
|
// AddClientConnection stores a new client connection
|
||||||
func (conns *ActiveConnections) AddNetworkConnection() {
|
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
|
||||||
atomic.AddInt32(&conns.networkConnections, 1)
|
conns.clients.add(ipAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNetworkConnection decrements the network connections counter
|
// RemoveClientConnection removes a disconnected client from the tracked ones
|
||||||
func (conns *ActiveConnections) RemoveNetworkConnection() {
|
func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
|
||||||
atomic.AddInt32(&conns.networkConnections, -1)
|
conns.clients.remove(ipAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
|
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
|
||||||
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
|
||||||
if Config.MaxTotalConnections == 0 {
|
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
num := atomic.LoadInt32(&conns.networkConnections)
|
if Config.MaxPerHostConnections > 0 {
|
||||||
if num > int32(Config.MaxTotalConnections) {
|
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
|
||||||
logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections)
|
logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
|
||||||
|
AddDefenderEvent(ipAddr, HostEventLimitExceeded)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if Config.MaxTotalConnections > 0 {
|
||||||
|
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
|
||||||
|
logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// on a single SFTP connection we could have multiple SFTP channels or commands
|
// on a single SFTP connection we could have multiple SFTP channels or commands
|
||||||
// so we check the estabilished connections too
|
// so we check the estabilished connections too
|
||||||
|
|
||||||
|
@ -721,6 +738,9 @@ func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
|
||||||
defer conns.RUnlock()
|
defer conns.RUnlock()
|
||||||
|
|
||||||
return len(conns.connections) < Config.MaxTotalConnections
|
return len(conns.connections) < Config.MaxTotalConnections
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns stats for active connections
|
// GetStats returns stats for active connections
|
||||||
|
|
|
@ -228,32 +228,61 @@ func TestRateLimitersIntegration(t *testing.T) {
|
||||||
|
|
||||||
func TestMaxConnections(t *testing.T) {
|
func TestMaxConnections(t *testing.T) {
|
||||||
oldValue := Config.MaxTotalConnections
|
oldValue := Config.MaxTotalConnections
|
||||||
Config.MaxTotalConnections = 1
|
perHost := Config.MaxPerHostConnections
|
||||||
|
|
||||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
Config.MaxPerHostConnections = 0
|
||||||
|
|
||||||
|
ipAddr := "192.168.7.8"
|
||||||
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
|
|
||||||
|
Config.MaxTotalConnections = 1
|
||||||
|
Config.MaxPerHostConnections = perHost
|
||||||
|
|
||||||
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
|
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
|
||||||
fakeConn := &fakeConnection{
|
fakeConn := &fakeConnection{
|
||||||
BaseConnection: c,
|
BaseConnection: c,
|
||||||
}
|
}
|
||||||
Connections.Add(fakeConn)
|
Connections.Add(fakeConn)
|
||||||
assert.Len(t, Connections.GetStats(), 1)
|
assert.Len(t, Connections.GetStats(), 1)
|
||||||
assert.False(t, Connections.IsNewConnectionAllowed())
|
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
|
|
||||||
res := Connections.Close(fakeConn.GetID())
|
res := Connections.Close(fakeConn.GetID())
|
||||||
assert.True(t, res)
|
assert.True(t, res)
|
||||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
|
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
|
||||||
|
|
||||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
Connections.AddNetworkConnection()
|
Connections.AddClientConnection(ipAddr)
|
||||||
Connections.AddNetworkConnection()
|
Connections.AddClientConnection(ipAddr)
|
||||||
assert.False(t, Connections.IsNewConnectionAllowed())
|
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
Connections.RemoveNetworkConnection()
|
Connections.RemoveClientConnection(ipAddr)
|
||||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
Connections.RemoveNetworkConnection()
|
Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
Config.MaxTotalConnections = oldValue
|
Config.MaxTotalConnections = oldValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaxConnectionPerHost(t *testing.T) {
|
||||||
|
oldValue := Config.MaxPerHostConnections
|
||||||
|
|
||||||
|
Config.MaxPerHostConnections = 2
|
||||||
|
|
||||||
|
ipAddr := "192.168.9.9"
|
||||||
|
Connections.AddClientConnection(ipAddr)
|
||||||
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
|
|
||||||
|
Connections.AddClientConnection(ipAddr)
|
||||||
|
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
|
|
||||||
|
Connections.AddClientConnection(ipAddr)
|
||||||
|
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
|
||||||
|
|
||||||
|
Connections.RemoveClientConnection(ipAddr)
|
||||||
|
Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
|
Config.MaxPerHostConnections = oldValue
|
||||||
|
}
|
||||||
|
|
||||||
func TestIdleConnections(t *testing.T) {
|
func TestIdleConnections(t *testing.T) {
|
||||||
configCopy := Config
|
configCopy := Config
|
||||||
|
|
||||||
|
@ -340,7 +369,7 @@ func TestCloseConnection(t *testing.T) {
|
||||||
fakeConn := &fakeConnection{
|
fakeConn := &fakeConnection{
|
||||||
BaseConnection: c,
|
BaseConnection: c,
|
||||||
}
|
}
|
||||||
assert.True(t, Connections.IsNewConnectionAllowed())
|
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
|
||||||
Connections.Add(fakeConn)
|
Connections.Add(fakeConn)
|
||||||
assert.Len(t, Connections.GetStats(), 1)
|
assert.Len(t, Connections.GetStats(), 1)
|
||||||
res := Connections.Close(fakeConn.GetID())
|
res := Connections.Close(fakeConn.GetID())
|
||||||
|
|
|
@ -23,7 +23,7 @@ const (
|
||||||
HostEventLoginFailed HostEvent = iota
|
HostEventLoginFailed HostEvent = iota
|
||||||
HostEventUserNotFound
|
HostEventUserNotFound
|
||||||
HostEventNoLoginTried
|
HostEventNoLoginTried
|
||||||
HostEventRateExceeded
|
HostEventLimitExceeded
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defender defines the interface that a defender must implements
|
// Defender defines the interface that a defender must implements
|
||||||
|
@ -51,8 +51,9 @@ type DefenderConfig struct {
|
||||||
ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
|
ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
|
||||||
// Score for valid login attempts, eg. user accounts that exist
|
// Score for valid login attempts, eg. user accounts that exist
|
||||||
ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
|
ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
|
||||||
// Score for rate exceeded events, generated from the rate limiters
|
// Score for limit exceeded events, generated from the rate limiters or for max connections
|
||||||
ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
|
// per-host exceeded
|
||||||
|
ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
|
||||||
// Defines the time window, in minutes, for tracking client errors.
|
// Defines the time window, in minutes, for tracking client errors.
|
||||||
// A host is banned if it has exceeded the defined threshold during
|
// A host is banned if it has exceeded the defined threshold during
|
||||||
// the last observation time minutes
|
// the last observation time minutes
|
||||||
|
@ -126,8 +127,8 @@ func (c *DefenderConfig) validate() error {
|
||||||
if c.ScoreValid >= c.Threshold {
|
if c.ScoreValid >= c.Threshold {
|
||||||
return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
|
return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
|
||||||
}
|
}
|
||||||
if c.ScoreRateExceeded >= c.Threshold {
|
if c.ScoreLimitExceeded >= c.Threshold {
|
||||||
return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
|
return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
|
||||||
}
|
}
|
||||||
if c.BanTime <= 0 {
|
if c.BanTime <= 0 {
|
||||||
return fmt.Errorf("invalid ban_time %v", c.BanTime)
|
return fmt.Errorf("invalid ban_time %v", c.BanTime)
|
||||||
|
@ -254,8 +255,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
|
||||||
switch event {
|
switch event {
|
||||||
case HostEventLoginFailed:
|
case HostEventLoginFailed:
|
||||||
score = d.config.ScoreValid
|
score = d.config.ScoreValid
|
||||||
case HostEventRateExceeded:
|
case HostEventLimitExceeded:
|
||||||
score = d.config.ScoreRateExceeded
|
score = d.config.ScoreLimitExceeded
|
||||||
case HostEventUserNotFound, HostEventNoLoginTried:
|
case HostEventUserNotFound, HostEventNoLoginTried:
|
||||||
score = d.config.ScoreInvalid
|
score = d.config.ScoreInvalid
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ func TestBasicDefender(t *testing.T) {
|
||||||
Threshold: 5,
|
Threshold: 5,
|
||||||
ScoreInvalid: 2,
|
ScoreInvalid: 2,
|
||||||
ScoreValid: 1,
|
ScoreValid: 1,
|
||||||
ScoreRateExceeded: 3,
|
ScoreLimitExceeded: 3,
|
||||||
ObservationTime: 15,
|
ObservationTime: 15,
|
||||||
EntriesSoftLimit: 1,
|
EntriesSoftLimit: 1,
|
||||||
EntriesHardLimit: 2,
|
EntriesHardLimit: 2,
|
||||||
|
@ -75,7 +75,7 @@ func TestBasicDefender(t *testing.T) {
|
||||||
|
|
||||||
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
|
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
|
||||||
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
|
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
|
||||||
defender.AddEvent("172.16.1.3", HostEventRateExceeded)
|
defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
|
||||||
assert.Equal(t, 0, defender.countHosts())
|
assert.Equal(t, 0, defender.countHosts())
|
||||||
|
|
||||||
testIP := "12.34.56.78"
|
testIP := "12.34.56.78"
|
||||||
|
@ -84,7 +84,7 @@ func TestBasicDefender(t *testing.T) {
|
||||||
assert.Equal(t, 0, defender.countBanned())
|
assert.Equal(t, 0, defender.countBanned())
|
||||||
assert.Equal(t, 1, defender.GetScore(testIP))
|
assert.Equal(t, 1, defender.GetScore(testIP))
|
||||||
assert.Nil(t, defender.GetBanTime(testIP))
|
assert.Nil(t, defender.GetBanTime(testIP))
|
||||||
defender.AddEvent(testIP, HostEventRateExceeded)
|
defender.AddEvent(testIP, HostEventLimitExceeded)
|
||||||
assert.Equal(t, 1, defender.countHosts())
|
assert.Equal(t, 1, defender.countHosts())
|
||||||
assert.Equal(t, 0, defender.countBanned())
|
assert.Equal(t, 0, defender.countBanned())
|
||||||
assert.Equal(t, 4, defender.GetScore(testIP))
|
assert.Equal(t, 4, defender.GetScore(testIP))
|
||||||
|
@ -317,11 +317,11 @@ func TestDefenderConfig(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
c.ScoreInvalid = 2
|
c.ScoreInvalid = 2
|
||||||
c.ScoreRateExceeded = 10
|
c.ScoreLimitExceeded = 10
|
||||||
err = c.validate()
|
err = c.validate()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
c.ScoreRateExceeded = 2
|
c.ScoreLimitExceeded = 2
|
||||||
c.ScoreValid = 10
|
c.ScoreValid = 10
|
||||||
err = c.validate()
|
err = c.validate()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
|
@ -149,7 +149,7 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
|
||||||
if delay > rl.maxDelay {
|
if delay > rl.maxDelay {
|
||||||
res.Cancel()
|
res.Cancel()
|
||||||
if rl.generateDefenderEvents && rl.globalBucket == nil {
|
if rl.generateDefenderEvents && rl.globalBucket == nil {
|
||||||
AddDefenderEvent(source, HostEventRateExceeded)
|
AddDefenderEvent(source, HostEventLimitExceeded)
|
||||||
}
|
}
|
||||||
return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
|
return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,6 +116,7 @@ func Init() {
|
||||||
ProxyAllowed: []string{},
|
ProxyAllowed: []string{},
|
||||||
PostConnectHook: "",
|
PostConnectHook: "",
|
||||||
MaxTotalConnections: 0,
|
MaxTotalConnections: 0,
|
||||||
|
MaxPerHostConnections: 20,
|
||||||
DefenderConfig: common.DefenderConfig{
|
DefenderConfig: common.DefenderConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
BanTime: 30,
|
BanTime: 30,
|
||||||
|
@ -123,7 +124,7 @@ func Init() {
|
||||||
Threshold: 15,
|
Threshold: 15,
|
||||||
ScoreInvalid: 2,
|
ScoreInvalid: 2,
|
||||||
ScoreValid: 1,
|
ScoreValid: 1,
|
||||||
ScoreRateExceeded: 3,
|
ScoreLimitExceeded: 3,
|
||||||
ObservationTime: 30,
|
ObservationTime: 30,
|
||||||
EntriesSoftLimit: 100,
|
EntriesSoftLimit: 100,
|
||||||
EntriesHardLimit: 150,
|
EntriesHardLimit: 150,
|
||||||
|
@ -873,13 +874,14 @@ func setViperDefaults() {
|
||||||
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
|
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
|
||||||
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
|
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
|
||||||
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
|
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
|
||||||
|
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
|
||||||
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
|
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
|
||||||
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
|
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
|
||||||
viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
|
viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
|
||||||
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
|
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
|
||||||
viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
|
viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
|
||||||
viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
|
viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
|
||||||
viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded)
|
viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
|
||||||
viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
|
viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
|
||||||
viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
|
viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
|
||||||
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)
|
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)
|
||||||
|
|
|
@ -8,7 +8,7 @@ You can configure a score for each event type:
|
||||||
|
|
||||||
- `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
|
- `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
|
||||||
- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
|
- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
|
||||||
- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`.
|
- `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
|
||||||
|
|
||||||
And then you can configure:
|
And then you can configure:
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,8 @@ The configuration file contains the following sections:
|
||||||
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
|
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
|
||||||
- `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable
|
- `startup_hook`, string. Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. If you define an HTTP URL it will be invoked using a `GET` request. Please note that SFTPGo services may not yet be available when this hook is run. Leave empty do disable
|
||||||
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
|
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
|
||||||
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
|
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0.
|
||||||
|
- `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
|
||||||
- `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
|
- `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
|
||||||
- `enabled`, boolean. Default `false`.
|
- `enabled`, boolean. Default `false`.
|
||||||
- `ban_time`, integer. Ban time in minutes.
|
- `ban_time`, integer. Ban time in minutes.
|
||||||
|
@ -72,7 +73,7 @@ The configuration file contains the following sections:
|
||||||
- `threshold`, integer. Threshold value for banning a client.
|
- `threshold`, integer. Threshold value for banning a client.
|
||||||
- `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
|
- `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
|
||||||
- `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
|
- `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
|
||||||
- `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits.
|
- `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
|
||||||
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
|
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
|
||||||
- `entries_soft_limit`, integer.
|
- `entries_soft_limit`, integer.
|
||||||
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.
|
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.
|
||||||
|
|
|
@ -18,7 +18,7 @@ The supported protocols are:
|
||||||
You can also define two types of rate limiters:
|
You can also define two types of rate limiters:
|
||||||
|
|
||||||
- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
|
- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
|
||||||
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
|
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
|
||||||
|
|
||||||
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
|
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
|
||||||
|
|
||||||
|
|
|
@ -779,13 +779,36 @@ func TestMaxConnections(t *testing.T) {
|
||||||
common.Config.MaxTotalConnections = oldValue
|
common.Config.MaxTotalConnections = oldValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaxPerHostConnections(t *testing.T) {
|
||||||
|
oldValue := common.Config.MaxPerHostConnections
|
||||||
|
common.Config.MaxPerHostConnections = 1
|
||||||
|
|
||||||
|
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
client, err := getFTPClient(user, true, nil)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
err = checkBasicFTP(client)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = getFTPClient(user, false, nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
err = client.Quit()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = os.RemoveAll(user.GetHomeDir())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
common.Config.MaxPerHostConnections = oldValue
|
||||||
|
}
|
||||||
|
|
||||||
func TestRateLimiter(t *testing.T) {
|
func TestRateLimiter(t *testing.T) {
|
||||||
oldConfig := config.GetCommonConfig()
|
oldConfig := config.GetCommonConfig()
|
||||||
|
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 5
|
cfg.DefenderConfig.Threshold = 5
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 3
|
cfg.DefenderConfig.ScoreLimitExceeded = 3
|
||||||
cfg.RateLimitersConfig = []common.RateLimiterConfig{
|
cfg.RateLimitersConfig = []common.RateLimiterConfig{
|
||||||
{
|
{
|
||||||
Average: 1,
|
Average: 1,
|
||||||
|
@ -843,7 +866,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
|
||||||
err := common.Initialize(cfg)
|
err := common.Initialize(cfg)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -135,13 +135,13 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
|
||||||
|
|
||||||
// ClientConnected is called to send the very first welcome message
|
// ClientConnected is called to send the very first welcome message
|
||||||
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
|
||||||
common.Connections.AddNetworkConnection()
|
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
|
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
|
||||||
|
common.Connections.AddClientConnection(ipAddr)
|
||||||
if common.IsBanned(ipAddr) {
|
if common.IsBanned(ipAddr) {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
|
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
|
||||||
return "Access denied: banned client IP", common.ErrConnectionDenied
|
return "Access denied: banned client IP", common.ErrConnectionDenied
|
||||||
}
|
}
|
||||||
if !common.Connections.IsNewConnectionAllowed() {
|
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
|
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
|
||||||
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
|
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
|
||||||
s.cleanTLSConnVerification(cc.ID())
|
s.cleanTLSConnVerification(cc.ID())
|
||||||
connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
|
connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
|
||||||
common.Connections.Remove(connID)
|
common.Connections.Remove(connID)
|
||||||
common.Connections.RemoveNetworkConnection()
|
common.Connections.RemoveClientConnection(utils.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthUser authenticates the user and selects an handling driver
|
// AuthUser authenticates the user and selects an handling driver
|
||||||
|
|
|
@ -2945,7 +2945,7 @@ func TestDefenderAPI(t *testing.T) {
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
|
||||||
err := common.Initialize(cfg)
|
err := common.Initialize(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -4615,7 +4615,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
|
||||||
err := common.Initialize(cfg)
|
err := common.Initialize(cfg)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -110,8 +110,10 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
|
||||||
|
|
||||||
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
|
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
|
||||||
r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
|
r.Body = http.MaxBytesReader(w, r.Body, maxLoginPostSize)
|
||||||
common.Connections.AddNetworkConnection()
|
|
||||||
defer common.Connections.RemoveNetworkConnection()
|
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||||
|
common.Connections.AddClientConnection(ipAddr)
|
||||||
|
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
renderClientLoginPage(w, err.Error())
|
renderClientLoginPage(w, err.Error())
|
||||||
|
@ -128,8 +130,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||||
if !common.Connections.IsNewConnectionAllowed() {
|
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
||||||
renderClientLoginPage(w, "configured connections limit reached")
|
renderClientLoginPage(w, "configured connections limit reached")
|
||||||
return
|
return
|
||||||
|
|
|
@ -279,20 +279,20 @@ func handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
|
func handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
|
||||||
common.Connections.AddNetworkConnection()
|
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||||
defer common.Connections.RemoveNetworkConnection()
|
common.Connections.AddClientConnection(ipAddr)
|
||||||
|
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
claims, err := getTokenClaims(r)
|
claims, err := getTokenClaims(r)
|
||||||
if err != nil || claims.Username == "" {
|
if err != nil || claims.Username == "" {
|
||||||
renderClientForbiddenPage(w, r, "Invalid token claims")
|
renderClientForbiddenPage(w, r, "Invalid token claims")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !common.Connections.IsNewConnectionAllowed() {
|
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
|
||||||
renderClientForbiddenPage(w, r, "configured connections limit reached")
|
renderClientForbiddenPage(w, r, "configured connections limit reached")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
|
||||||
if common.IsBanned(ipAddr) {
|
if common.IsBanned(ipAddr) {
|
||||||
renderClientForbiddenPage(w, r, "your IP address is banned")
|
renderClientForbiddenPage(w, r, "your IP address is banned")
|
||||||
return
|
return
|
||||||
|
|
|
@ -356,7 +356,7 @@ func canAcceptConnection(ip string) bool {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip)
|
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %#v is banned", ip)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if !common.Connections.IsNewConnectionAllowed() {
|
if !common.Connections.IsNewConnectionAllowed(ip) {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
|
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -378,10 +378,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
common.Connections.AddNetworkConnection()
|
|
||||||
defer common.Connections.RemoveNetworkConnection()
|
|
||||||
|
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||||
|
common.Connections.AddClientConnection(ipAddr)
|
||||||
|
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
if !canAcceptConnection(ipAddr) {
|
if !canAcceptConnection(ipAddr) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
|
|
|
@ -527,7 +527,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
|
||||||
err := common.Initialize(cfg)
|
err := common.Initialize(cfg)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -663,6 +663,9 @@ func TestOpenReadWritePerm(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConcurrency(t *testing.T) {
|
func TestConcurrency(t *testing.T) {
|
||||||
|
oldValue := common.Config.MaxPerHostConnections
|
||||||
|
common.Config.MaxPerHostConnections = 0
|
||||||
|
|
||||||
usePubKey := true
|
usePubKey := true
|
||||||
numLogins := 50
|
numLogins := 50
|
||||||
u := getTestUser(usePubKey)
|
u := getTestUser(usePubKey)
|
||||||
|
@ -747,6 +750,8 @@ func TestConcurrency(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = os.RemoveAll(user.GetHomeDir())
|
err = os.RemoveAll(user.GetHomeDir())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
common.Config.MaxPerHostConnections = oldValue
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyProtocol(t *testing.T) {
|
func TestProxyProtocol(t *testing.T) {
|
||||||
|
@ -2896,6 +2901,7 @@ func TestQuotaDisabledError(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:dupl
|
||||||
func TestMaxConnections(t *testing.T) {
|
func TestMaxConnections(t *testing.T) {
|
||||||
oldValue := common.Config.MaxTotalConnections
|
oldValue := common.Config.MaxTotalConnections
|
||||||
common.Config.MaxTotalConnections = 1
|
common.Config.MaxTotalConnections = 1
|
||||||
|
@ -2910,7 +2916,7 @@ func TestMaxConnections(t *testing.T) {
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
assert.NoError(t, checkBasicSFTP(client))
|
assert.NoError(t, checkBasicSFTP(client))
|
||||||
s, c, err := getSftpClient(user, usePubKey)
|
s, c, err := getSftpClient(user, usePubKey)
|
||||||
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
|
if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") {
|
||||||
c.Close()
|
c.Close()
|
||||||
s.Close()
|
s.Close()
|
||||||
}
|
}
|
||||||
|
@ -2923,6 +2929,34 @@ func TestMaxConnections(t *testing.T) {
|
||||||
common.Config.MaxTotalConnections = oldValue
|
common.Config.MaxTotalConnections = oldValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:dupl
|
||||||
|
func TestMaxPerHostConnections(t *testing.T) {
|
||||||
|
oldValue := common.Config.MaxPerHostConnections
|
||||||
|
common.Config.MaxPerHostConnections = 1
|
||||||
|
|
||||||
|
usePubKey := true
|
||||||
|
u := getTestUser(usePubKey)
|
||||||
|
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
conn, client, err := getSftpClient(user, usePubKey)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
defer conn.Close()
|
||||||
|
defer client.Close()
|
||||||
|
assert.NoError(t, checkBasicSFTP(client))
|
||||||
|
s, c, err := getSftpClient(user, usePubKey)
|
||||||
|
if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") {
|
||||||
|
c.Close()
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = os.RemoveAll(user.GetHomeDir())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
common.Config.MaxPerHostConnections = oldValue
|
||||||
|
}
|
||||||
|
|
||||||
func TestMaxSessions(t *testing.T) {
|
func TestMaxSessions(t *testing.T) {
|
||||||
usePubKey := false
|
usePubKey := false
|
||||||
u := getTestUser(usePubKey)
|
u := getTestUser(usePubKey)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
"startup_hook": "",
|
"startup_hook": "",
|
||||||
"post_connect_hook": "",
|
"post_connect_hook": "",
|
||||||
"max_total_connections": 0,
|
"max_total_connections": 0,
|
||||||
|
"max_per_host_connections": 20,
|
||||||
"defender": {
|
"defender": {
|
||||||
"enabled": false,
|
"enabled": false,
|
||||||
"ban_time": 30,
|
"ban_time": 30,
|
||||||
|
@ -19,7 +20,7 @@
|
||||||
"threshold": 15,
|
"threshold": 15,
|
||||||
"score_invalid": 2,
|
"score_invalid": 2,
|
||||||
"score_valid": 1,
|
"score_valid": 1,
|
||||||
"score_rate_exceeded": 3,
|
"score_limit_exceeded": 3,
|
||||||
"observation_time": 30,
|
"observation_time": 30,
|
||||||
"entries_soft_limit": 100,
|
"entries_soft_limit": 100,
|
||||||
"entries_hard_limit": 150,
|
"entries_hard_limit": 150,
|
||||||
|
|
|
@ -144,16 +144,18 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
common.Connections.AddNetworkConnection()
|
|
||||||
defer common.Connections.RemoveNetworkConnection()
|
|
||||||
|
|
||||||
if !common.Connections.IsNewConnectionAllowed() {
|
checkRemoteAddress(r)
|
||||||
|
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
||||||
|
|
||||||
|
common.Connections.AddClientConnection(ipAddr)
|
||||||
|
defer common.Connections.RemoveClientConnection(ipAddr)
|
||||||
|
|
||||||
|
if !common.Connections.IsNewConnectionAllowed(ipAddr) {
|
||||||
logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
|
logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
|
||||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
|
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
checkRemoteAddress(r)
|
|
||||||
ipAddr := utils.GetIPFromRemoteAddress(r.RemoteAddr)
|
|
||||||
if common.IsBanned(ipAddr) {
|
if common.IsBanned(ipAddr) {
|
||||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
|
|
|
@ -747,7 +747,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreRateExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
|
||||||
err := common.Initialize(cfg)
|
err := common.Initialize(cfg)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -934,6 +934,33 @@ func TestMaxConnections(t *testing.T) {
|
||||||
common.Config.MaxTotalConnections = oldValue
|
common.Config.MaxTotalConnections = oldValue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaxPerHostConnections(t *testing.T) {
|
||||||
|
oldValue := common.Config.MaxPerHostConnections
|
||||||
|
common.Config.MaxPerHostConnections = 1
|
||||||
|
|
||||||
|
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
client := getWebDavClient(user, true, nil)
|
||||||
|
assert.NoError(t, checkBasicFunc(client))
|
||||||
|
// now add a fake connection
|
||||||
|
addrs, err := net.LookupHost("localhost")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for _, addr := range addrs {
|
||||||
|
common.Connections.AddClientConnection(addr)
|
||||||
|
}
|
||||||
|
assert.Error(t, checkBasicFunc(client))
|
||||||
|
for _, addr := range addrs {
|
||||||
|
common.Connections.RemoveClientConnection(addr)
|
||||||
|
}
|
||||||
|
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = os.RemoveAll(user.GetHomeDir())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, common.Connections.GetStats(), 0)
|
||||||
|
|
||||||
|
common.Config.MaxPerHostConnections = oldValue
|
||||||
|
}
|
||||||
|
|
||||||
func TestMaxSessions(t *testing.T) {
|
func TestMaxSessions(t *testing.T) {
|
||||||
u := getTestUser()
|
u := getTestUser()
|
||||||
u.MaxSessions = 1
|
u.MaxSessions = 1
|
||||||
|
|
Loading…
Reference in a new issue