add rate limiting support

Nicola Murino 2021-04-18 12:31:06 +02:00
parent 124c471a2b
commit 112e3b2fc2
22 changed files with 876 additions and 51 deletions


@ -34,6 +34,7 @@ Several storage backends are supported: local filesystem, encrypted local filesy
- Configurable custom commands and/or HTTP notifications on file upload, download, pre-delete, delete, rename, on SSH commands and on user add, update and delete.
- Automatically terminating idle connections.
- Automatic blocklist management is supported using the built-in [defender](./docs/defender.md).
- Per-protocol [rate limiting](./docs/rate-limiting.md) is supported and can optionally be connected to the built-in defender to automatically block hosts that repeatedly exceed the configured limit.
- Atomic uploads are configurable.
- Support for Git repositories over SSH.
- SCP and rsync are supported.


@ -104,6 +104,8 @@ var (
idleTimeoutTicker *time.Ticker
idleTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV}
// the map key is the protocol, for each protocol we can have multiple rate limiters
rateLimiters map[string][]*rateLimiter
)
// Initialize sets the common configuration
@ -123,6 +125,32 @@ func Initialize(c Configuration) error {
logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig)
Config.defender = defender
}
rateLimiters = make(map[string][]*rateLimiter)
for _, rlCfg := range c.RateLimitersConfig {
if rlCfg.isEnabled() {
if err := rlCfg.validate(); err != nil {
return fmt.Errorf("rate limiters initialization error: %v", err)
}
rateLimiter := rlCfg.getLimiter()
for _, protocol := range rlCfg.Protocols {
rateLimiters[protocol] = append(rateLimiters[protocol], rateLimiter)
}
}
}
return nil
}
// LimitRate blocks until all the rate limiters configured for the
// given protocol allow one event to happen.
// It returns an error if the time to wait exceeds the max
// allowed delay.
func LimitRate(protocol, ip string) error {
for _, limiter := range rateLimiters[protocol] {
if err := limiter.Wait(ip); err != nil {
logger.Debug(logSender, "", "protocol %v ip %v: %v", protocol, ip, err)
return err
}
}
return nil
}
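The functions added above are the public surface of the feature: `Initialize` builds the per-protocol limiter map from `RateLimitersConfig`, and `LimitRate` is called by the protocol servers before serving a request (the ftpd, sftpd and webdavd changes later in this diff do exactly that). A minimal sketch of wiring them together from outside the package; the IP address and limiter values are illustrative, not defaults:

```go
package main

import (
	"fmt"

	"github.com/drakkan/sftpgo/common"
)

func main() {
	cfg := common.Configuration{
		RateLimitersConfig: []common.RateLimiterConfig{
			{
				Average:          1,    // at most 1 request...
				Period:           1000, // ...per 1000 milliseconds
				Burst:            1,
				Type:             2, // per-source limiter
				Protocols:        []string{common.ProtocolFTP},
				EntriesSoftLimit: 100,
				EntriesHardLimit: 150,
			},
		},
	}
	if err := common.Initialize(cfg); err != nil {
		panic(err)
	}
	// Before serving a new FTP request, honor the limiters for the client IP.
	if err := common.LimitRate(common.ProtocolFTP, "203.0.113.10"); err != nil {
		fmt.Println("required delay too long, rejecting:", err)
	}
}
```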
@ -324,7 +352,9 @@ type Configuration struct {
// Maximum number of concurrent client connections. 0 means unlimited
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
// Defender configuration
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
// Rate limiter configurations
RateLimitersConfig []RateLimiterConfig `json:"rate_limiters" mapstructure:"rate_limiters"`
idleTimeoutAsDuration time.Duration
idleLoginTimeout time.Duration
defender Defender


@ -165,6 +165,64 @@ func TestDefenderIntegration(t *testing.T) {
Config = configCopy
}
func TestRateLimitersIntegration(t *testing.T) {
// by default defender is nil
configCopy := Config
Config.RateLimitersConfig = []RateLimiterConfig{
{
Average: 100,
Period: 10,
Burst: 5,
Type: int(rateLimiterTypeGlobal),
Protocols: rateLimiterProtocolValues,
},
{
Average: 1,
Period: 1000,
Burst: 1,
Type: int(rateLimiterTypeSource),
Protocols: []string{ProtocolWebDAV, ProtocolWebDAV, ProtocolFTP},
GenerateDefenderEvents: true,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
},
}
err := Initialize(Config)
assert.Error(t, err)
Config.RateLimitersConfig[0].Period = 1000
err = Initialize(Config)
assert.NoError(t, err)
assert.Len(t, rateLimiters, 3)
assert.Len(t, rateLimiters[ProtocolSSH], 1)
assert.Len(t, rateLimiters[ProtocolFTP], 2)
assert.Len(t, rateLimiters[ProtocolWebDAV], 2)
source1 := "127.1.1.1"
source2 := "127.1.1.2"
err = LimitRate(ProtocolSSH, source1)
assert.NoError(t, err)
err = LimitRate(ProtocolFTP, source1)
assert.NoError(t, err)
// sleep to allow the global rate limiter to add a new token to its bucket.
// This sleep is not enough to refill the per-source bucket
time.Sleep(20 * time.Millisecond)
err = LimitRate(ProtocolWebDAV, source2)
assert.NoError(t, err)
err = LimitRate(ProtocolFTP, source1)
assert.Error(t, err)
err = LimitRate(ProtocolWebDAV, source2)
assert.Error(t, err)
err = LimitRate(ProtocolSSH, source1)
assert.NoError(t, err)
err = LimitRate(ProtocolSSH, source2)
assert.NoError(t, err)
Config = configCopy
}
func TestMaxConnections(t *testing.T) {
oldValue := Config.MaxTotalConnections
Config.MaxTotalConnections = 1


@ -23,6 +23,7 @@ const (
HostEventLoginFailed HostEvent = iota
HostEventUserNotFound
HostEventNoLoginTried
HostEventRateExceeded
)
// Defender defines the interface that a defender must implements
@ -50,6 +51,8 @@ type DefenderConfig struct {
ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
// Score for valid login attempts, eg. user accounts that exist
ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
// Score for rate exceeded events, generated from the rate limiters
ScoreRateExceeded int `json:"score_rate_exceeded" mapstructure:"score_rate_exceeded"`
// Defines the time window, in minutes, for tracking client errors.
// A host is banned if it has exceeded the defined threshold during
// the last observation time minutes
@ -123,6 +126,9 @@ func (c *DefenderConfig) validate() error {
if c.ScoreValid >= c.Threshold {
return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
}
if c.ScoreRateExceeded >= c.Threshold {
return fmt.Errorf("score_rate_exceeded %v cannot be greater than threshold %v", c.ScoreRateExceeded, c.Threshold)
}
if c.BanTime <= 0 {
return fmt.Errorf("invalid ban_time %v", c.BanTime)
}
@ -248,6 +254,8 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
switch event {
case HostEventLoginFailed:
score = d.config.ScoreValid
case HostEventRateExceeded:
score = d.config.ScoreRateExceeded
case HostEventUserNotFound, HostEventNoLoginTried:
score = d.config.ScoreInvalid
}
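To make the new score concrete, here is a hedged, test-style sketch (not part of this commit) of how repeated rate-exceeded events push a host toward the ban threshold. The function name is hypothetical; it assumes the `common` package context, the `testing`/testify imports used by `defender_test.go`, and that the value returned by `newInMemoryDefender` exposes `AddEvent`, `GetScore` and `GetBanTime` as used in the tests below:

```go
// Illustrative only: with ScoreRateExceeded = 3 and Threshold = 15, five
// rate-exceeded events from the same host are enough to trigger a ban.
func TestRateExceededScoringExample(t *testing.T) {
	config := &DefenderConfig{
		Enabled:           true,
		BanTime:           30,
		BanTimeIncrement:  2,
		Threshold:         15,
		ScoreInvalid:      2,
		ScoreValid:        1,
		ScoreRateExceeded: 3,
		ObservationTime:   30,
		EntriesSoftLimit:  1,
		EntriesHardLimit:  2,
	}
	defender, err := newInMemoryDefender(config)
	assert.NoError(t, err)

	testIP := "203.0.113.10"
	for i := 0; i < 4; i++ {
		defender.AddEvent(testIP, HostEventRateExceeded) // +3 per event
	}
	assert.Equal(t, 12, defender.GetScore(testIP)) // still below the threshold
	defender.AddEvent(testIP, HostEventRateExceeded)
	assert.NotNil(t, defender.GetBanTime(testIP)) // threshold reached, banned
}
```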


@ -41,17 +41,18 @@ func TestBasicDefender(t *testing.T) {
assert.NoError(t, err)
config := &DefenderConfig{
Enabled: true,
BanTime: 10,
BanTimeIncrement: 2,
Threshold: 5,
ScoreInvalid: 2,
ScoreValid: 1,
ObservationTime: 15,
EntriesSoftLimit: 1,
EntriesHardLimit: 2,
SafeListFile: "slFile",
BlockListFile: "blFile",
Enabled: true,
BanTime: 10,
BanTimeIncrement: 2,
Threshold: 5,
ScoreInvalid: 2,
ScoreValid: 1,
ScoreRateExceeded: 3,
ObservationTime: 15,
EntriesSoftLimit: 1,
EntriesHardLimit: 2,
SafeListFile: "slFile",
BlockListFile: "blFile",
}
_, err = newInMemoryDefender(config)
@ -74,6 +75,7 @@ func TestBasicDefender(t *testing.T) {
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
defender.AddEvent("172.16.1.3", HostEventRateExceeded)
assert.Equal(t, 0, defender.countHosts())
testIP := "12.34.56.78"
@ -82,10 +84,10 @@ func TestBasicDefender(t *testing.T) {
assert.Equal(t, 0, defender.countBanned())
assert.Equal(t, 1, defender.GetScore(testIP))
assert.Nil(t, defender.GetBanTime(testIP))
defender.AddEvent(testIP, HostEventNoLoginTried)
defender.AddEvent(testIP, HostEventRateExceeded)
assert.Equal(t, 1, defender.countHosts())
assert.Equal(t, 0, defender.countBanned())
assert.Equal(t, 3, defender.GetScore(testIP))
assert.Equal(t, 4, defender.GetScore(testIP))
defender.AddEvent(testIP, HostEventNoLoginTried)
assert.Equal(t, 0, defender.countHosts())
assert.Equal(t, 1, defender.countBanned())
@ -315,6 +317,11 @@ func TestDefenderConfig(t *testing.T) {
require.Error(t, err)
c.ScoreInvalid = 2
c.ScoreRateExceeded = 10
err = c.validate()
require.Error(t, err)
c.ScoreRateExceeded = 2
c.ScoreValid = 10
err = c.validate()
require.Error(t, err)

common/ratelimiter.go (new file, 229 lines)

@ -0,0 +1,229 @@
package common
import (
"errors"
"fmt"
"sort"
"sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"github.com/drakkan/sftpgo/utils"
)
var (
errNoBucket = errors.New("no bucket found")
errReserve = errors.New("unable to reserve token")
rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV}
)
// RateLimiterType defines the supported rate limiters types
type RateLimiterType int
// Supported rate limiter types
const (
rateLimiterTypeGlobal RateLimiterType = iota + 1
rateLimiterTypeSource
)
// RateLimiterConfig defines the configuration for a rate limiter
type RateLimiterConfig struct {
// Average defines the maximum rate allowed. 0 means disabled
Average int64 `json:"average" mapstructure:"average"`
// Period defines the period in milliseconds. Default: 1000 (1 second).
// The rate is actually defined by dividing average by period.
// So for a rate below 1 req/s, one needs to define a period larger than a second.
Period int64 `json:"period" mapstructure:"period"`
// Burst is the maximum number of requests allowed to go through in the
// same arbitrarily small period of time. Default: 1.
Burst int `json:"burst" mapstructure:"burst"`
// Type defines the rate limiter type:
// - rateLimiterTypeGlobal is a global rate limiter independent from the source
// - rateLimiterTypeSource is a per-source rate limiter
Type int `json:"type" mapstructure:"type"`
// Protocols defines the protocols for this rate limiter.
// Available protocols are: "SFTP", "FTP", "DAV".
// A rate limiter with no protocols defined is disabled
Protocols []string `json:"protocols" mapstructure:"protocols"`
// If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
// a new defender event will be generated
GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
// The number of per-ip rate limiters kept in memory will vary between the
// soft and hard limit
EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
}
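As an illustration of the fields above (values arbitrary, variable name hypothetical, not part of this commit): a per-source limiter allowing 0.5 requests per second for FTP, feeding the defender when the limit is exceeded, could be declared inside this package as:

```go
// Illustrative only: 0.5 req/s per source (Average/Period = 1 per 2000 ms),
// bucket size 2, FTP only, generating defender events when exceeded.
var exampleLimiterConfig = RateLimiterConfig{
	Average:                1,
	Period:                 2000,
	Burst:                  2,
	Type:                   int(rateLimiterTypeSource),
	Protocols:              []string{ProtocolFTP},
	GenerateDefenderEvents: true,
	EntriesSoftLimit:       100,
	EntriesHardLimit:       150,
}
```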
func (r *RateLimiterConfig) isEnabled() bool {
return r.Average > 0 && len(r.Protocols) > 0
}
func (r *RateLimiterConfig) validate() error {
if r.Burst < 1 {
return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
}
if r.Period < 100 {
return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
}
if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
return fmt.Errorf("invalid type %v", r.Type)
}
if r.Type != int(rateLimiterTypeGlobal) {
if r.EntriesSoftLimit <= 0 {
return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
}
if r.EntriesHardLimit <= r.EntriesSoftLimit {
return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
}
}
r.Protocols = utils.RemoveDuplicates(r.Protocols)
for _, protocol := range r.Protocols {
if !utils.IsStringInSlice(protocol, rateLimiterProtocolValues) {
return fmt.Errorf("invalid protocol %#v", protocol)
}
}
return nil
}
func (r *RateLimiterConfig) getLimiter() *rateLimiter {
limiter := &rateLimiter{
burst: r.Burst,
globalBucket: nil,
generateDefenderEvents: r.GenerateDefenderEvents,
}
var maxDelay time.Duration
period := time.Duration(r.Period) * time.Millisecond
rtl := float64(r.Average*int64(time.Second)) / float64(period)
limiter.rate = rate.Limit(rtl)
if rtl < 1 {
maxDelay = period / 2
} else {
maxDelay = time.Second / (time.Duration(rtl) * 2)
}
if maxDelay > 10*time.Second {
maxDelay = 10 * time.Second
}
limiter.maxDelay = maxDelay
limiter.buckets = sourceBuckets{
buckets: make(map[string]sourceRateLimiter),
hardLimit: r.EntriesHardLimit,
softLimit: r.EntriesSoftLimit,
}
if r.Type != int(rateLimiterTypeSource) {
limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
}
return limiter
}
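For concreteness, a standalone sketch that reproduces the `maxDelay` arithmetic above (the helper `maxDelayFor` is hypothetical); the printed values agree with the assertions in `TestRateLimiterConfig` later in this diff:

```go
package main

import (
	"fmt"
	"time"
)

// maxDelayFor mirrors the arithmetic in getLimiter (illustrative only).
func maxDelayFor(average, periodMs int64) time.Duration {
	period := time.Duration(periodMs) * time.Millisecond
	rtl := float64(average*int64(time.Second)) / float64(period)
	var maxDelay time.Duration
	if rtl < 1 {
		maxDelay = period / 2
	} else {
		maxDelay = time.Second / (time.Duration(rtl) * 2)
	}
	if maxDelay > 10*time.Second {
		maxDelay = 10 * time.Second
	}
	return maxDelay
}

func main() {
	fmt.Println(maxDelayFor(1, 10000))  // 5s: 0.1 req/s, half of the 10s period
	fmt.Println(maxDelayFor(100, 1000)) // 5ms: 100 req/s
	fmt.Println(maxDelayFor(1, 100000)) // 10s: capped at the 10 second ceiling
}
```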
// RateLimiter defines a rate limiter
type rateLimiter struct {
rate rate.Limit
burst int
maxDelay time.Duration
globalBucket *rate.Limiter
buckets sourceBuckets
generateDefenderEvents bool
}
// Wait blocks until the limit allows one event to happen
// or returns an error if the time to wait exceeds the max
// allowed delay
func (rl *rateLimiter) Wait(source string) error {
var res *rate.Reservation
if rl.globalBucket != nil {
res = rl.globalBucket.Reserve()
} else {
var err error
res, err = rl.buckets.reserve(source)
if err != nil {
rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
res = rl.buckets.addAndReserve(rateLimiter, source)
}
}
if !res.OK() {
return errReserve
}
delay := res.Delay()
if delay > rl.maxDelay {
res.Cancel()
if rl.generateDefenderEvents && rl.globalBucket == nil {
AddDefenderEvent(source, HostEventRateExceeded)
}
return fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
}
time.Sleep(delay)
return nil
}
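The method above builds on the reservation API of `golang.org/x/time/rate`. A minimal standalone sketch of that pattern, independent of SFTPGo (the limiter parameters and delay cap are arbitrary, not SFTPGo defaults):

```go
package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

// Reservation pattern used by Wait: reserve a token, inspect the required
// delay, then either sleep for it or cancel the reservation when the delay
// exceeds a maximum.
func main() {
	limiter := rate.NewLimiter(rate.Limit(1), 1) // 1 event per second, burst 1
	maxDelay := 100 * time.Millisecond

	for i := 0; i < 3; i++ {
		res := limiter.Reserve()
		if !res.OK() {
			fmt.Println("reservation impossible")
			continue
		}
		delay := res.Delay()
		if delay > maxDelay {
			res.Cancel() // return the token, we are not going to wait
			fmt.Println("rejected, required delay:", delay)
			continue
		}
		time.Sleep(delay)
		fmt.Println("allowed after:", delay)
	}
}
```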
type sourceRateLimiter struct {
lastActivity int64
bucket *rate.Limiter
}
func (s *sourceRateLimiter) updateLastActivity() {
atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
}
func (s *sourceRateLimiter) getLastActivity() int64 {
return atomic.LoadInt64(&s.lastActivity)
}
type sourceBuckets struct {
sync.RWMutex
buckets map[string]sourceRateLimiter
hardLimit int
softLimit int
}
func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
b.RLock()
defer b.RUnlock()
if src, ok := b.buckets[source]; ok {
src.updateLastActivity()
return src.bucket.Reserve(), nil
}
return nil, errNoBucket
}
func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
b.Lock()
defer b.Unlock()
b.cleanup()
src := sourceRateLimiter{
bucket: r,
}
src.updateLastActivity()
b.buckets[source] = src
return src.bucket.Reserve()
}
func (b *sourceBuckets) cleanup() {
if len(b.buckets) >= b.hardLimit {
numToRemove := len(b.buckets) - b.softLimit
kvList := make(kvList, 0, len(b.buckets))
for k, v := range b.buckets {
kvList = append(kvList, kv{
Key: k,
Value: v.getLastActivity(),
})
}
sort.Sort(kvList)
for idx, kv := range kvList {
if idx >= numToRemove {
break
}
delete(b.buckets, kv.Key)
}
}
}

common/ratelimiter_test.go (new file, 135 lines)

@ -0,0 +1,135 @@
package common
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimiterConfig(t *testing.T) {
config := RateLimiterConfig{}
err := config.validate()
require.Error(t, err)
config.Burst = 1
config.Period = 10
err = config.validate()
require.Error(t, err)
config.Period = 1000
config.Type = 100
err = config.validate()
require.Error(t, err)
config.Type = int(rateLimiterTypeSource)
config.EntriesSoftLimit = 0
err = config.validate()
require.Error(t, err)
config.EntriesSoftLimit = 150
config.EntriesHardLimit = 0
err = config.validate()
require.Error(t, err)
config.EntriesHardLimit = 200
config.Protocols = []string{"unsupported protocol"}
err = config.validate()
require.Error(t, err)
config.Protocols = rateLimiterProtocolValues
err = config.validate()
require.NoError(t, err)
limiter := config.getLimiter()
require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
require.Nil(t, limiter.globalBucket)
config.Type = int(rateLimiterTypeGlobal)
config.Average = 1
config.Period = 10000
limiter = config.getLimiter()
require.Equal(t, 5*time.Second, limiter.maxDelay)
require.NotNil(t, limiter.globalBucket)
config.Period = 100000
limiter = config.getLimiter()
require.Equal(t, 10*time.Second, limiter.maxDelay)
config.Period = 500
config.Average = 1
limiter = config.getLimiter()
require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
}
func TestRateLimiter(t *testing.T) {
config := RateLimiterConfig{
Average: 1,
Period: 1000,
Burst: 1,
Type: int(rateLimiterTypeGlobal),
Protocols: rateLimiterProtocolValues,
}
limiter := config.getLimiter()
err := limiter.Wait("")
require.NoError(t, err)
err = limiter.Wait("")
require.Error(t, err)
config.Type = int(rateLimiterTypeSource)
config.GenerateDefenderEvents = true
config.EntriesSoftLimit = 5
config.EntriesHardLimit = 10
limiter = config.getLimiter()
source := "192.168.1.2"
err = limiter.Wait(source)
require.NoError(t, err)
err = limiter.Wait(source)
require.Error(t, err)
// a different source should work
err = limiter.Wait(source + "1")
require.NoError(t, err)
config.Burst = 0
limiter = config.getLimiter()
err = limiter.Wait(source)
require.ErrorIs(t, err, errReserve)
}
func TestLimiterCleanup(t *testing.T) {
config := RateLimiterConfig{
Average: 100,
Period: 1000,
Burst: 1,
Type: int(rateLimiterTypeSource),
Protocols: rateLimiterProtocolValues,
EntriesSoftLimit: 1,
EntriesHardLimit: 3,
}
limiter := config.getLimiter()
source1 := "10.8.0.1"
source2 := "10.8.0.2"
source3 := "10.8.0.3"
source4 := "10.8.0.4"
err := limiter.Wait(source1)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
err = limiter.Wait(source2)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
assert.Len(t, limiter.buckets.buckets, 2)
_, ok := limiter.buckets.buckets[source1]
assert.True(t, ok)
_, ok = limiter.buckets.buckets[source2]
assert.True(t, ok)
err = limiter.Wait(source3)
assert.NoError(t, err)
assert.Len(t, limiter.buckets.buckets, 3)
_, ok = limiter.buckets.buckets[source1]
assert.True(t, ok)
_, ok = limiter.buckets.buckets[source2]
assert.True(t, ok)
_, ok = limiter.buckets.buckets[source3]
assert.True(t, ok)
time.Sleep(20 * time.Millisecond)
err = limiter.Wait(source4)
assert.NoError(t, err)
assert.Len(t, limiter.buckets.buckets, 2)
_, ok = limiter.buckets.buckets[source3]
assert.True(t, ok)
_, ok = limiter.buckets.buckets[source4]
assert.True(t, ok)
}


@ -69,6 +69,16 @@ var (
ClientAuthType: 0,
TLSCipherSuites: nil,
}
defaultRateLimiter = common.RateLimiterConfig{
Average: 0,
Period: 1000,
Burst: 1,
Type: 2,
Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV},
GenerateDefenderEvents: false,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
}
)
type globalConfig struct {
@ -106,18 +116,20 @@ func Init() {
PostConnectHook: "",
MaxTotalConnections: 0,
DefenderConfig: common.DefenderConfig{
Enabled: false,
BanTime: 30,
BanTimeIncrement: 50,
Threshold: 15,
ScoreInvalid: 2,
ScoreValid: 1,
ObservationTime: 30,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
SafeListFile: "",
BlockListFile: "",
Enabled: false,
BanTime: 30,
BanTimeIncrement: 50,
Threshold: 15,
ScoreInvalid: 2,
ScoreValid: 1,
ScoreRateExceeded: 3,
ObservationTime: 30,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
SafeListFile: "",
BlockListFile: "",
},
RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter},
},
SFTPD: sftpd.Configuration{
Banner: defaultSFTPDBanner,
@ -538,8 +550,8 @@ func loadBindingsFromEnv() {
checkWebDAVDBindingCompatibility()
checkHTTPDBindingCompatibility()
maxBindings := make([]int, 10)
for idx := range maxBindings {
for idx := 0; idx < 10; idx++ {
getRateLimitersFromEnv(idx)
getSFTPDBindindFromEnv(idx)
getFTPDBindingFromEnv(idx)
getWebDAVDBindingFromEnv(idx)
@ -548,6 +560,71 @@ func loadBindingsFromEnv() {
}
}
func getRateLimitersFromEnv(idx int) {
rtlConfig := defaultRateLimiter
if len(globalConf.Common.RateLimitersConfig) > idx {
rtlConfig = globalConf.Common.RateLimitersConfig[idx]
}
isSet := false
average, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__AVERAGE", idx))
if ok {
rtlConfig.Average = average
isSet = true
}
period, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PERIOD", idx))
if ok {
rtlConfig.Period = period
isSet = true
}
burst, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__BURST", idx))
if ok {
rtlConfig.Burst = int(burst)
isSet = true
}
rtlType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__TYPE", idx))
if ok {
rtlConfig.Type = int(rtlType)
isSet = true
}
protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PROTOCOLS", idx))
if ok {
rtlConfig.Protocols = protocols
isSet = true
}
generateEvents, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__GENERATE_DEFENDER_EVENTS", idx))
if ok {
rtlConfig.GenerateDefenderEvents = generateEvents
isSet = true
}
softLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_SOFT_LIMIT", idx))
if ok {
rtlConfig.EntriesSoftLimit = int(softLimit)
isSet = true
}
hardLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_HARD_LIMIT", idx))
if ok {
rtlConfig.EntriesHardLimit = int(hardLimit)
isSet = true
}
if isSet {
if len(globalConf.Common.RateLimitersConfig) > idx {
globalConf.Common.RateLimitersConfig[idx] = rtlConfig
} else {
globalConf.Common.RateLimitersConfig = append(globalConf.Common.RateLimitersConfig, rtlConfig)
}
}
}
func getSFTPDBindindFromEnv(idx int) {
binding := sftpd.Binding{
ApplyProxyConfig: true,
@ -560,7 +637,7 @@ func getSFTPDBindindFromEnv(idx int) {
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__PORT", idx))
if ok {
binding.Port = port
binding.Port = int(port)
isSet = true
}
@ -597,7 +674,7 @@ func getFTPDBindingFromEnv(idx int) {
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PORT", idx))
if ok {
binding.Port = port
binding.Port = int(port)
isSet = true
}
@ -615,7 +692,7 @@ func getFTPDBindingFromEnv(idx int) {
tlsMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_MODE", idx))
if ok {
binding.TLSMode = tlsMode
binding.TLSMode = int(tlsMode)
isSet = true
}
@ -627,7 +704,7 @@ func getFTPDBindingFromEnv(idx int) {
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx))
if ok {
binding.ClientAuthType = clientAuthType
binding.ClientAuthType = int(clientAuthType)
isSet = true
}
@ -656,7 +733,7 @@ func getWebDAVDBindingFromEnv(idx int) {
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PORT", idx))
if ok {
binding.Port = port
binding.Port = int(port)
isSet = true
}
@ -674,7 +751,7 @@ func getWebDAVDBindingFromEnv(idx int) {
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx))
if ok {
binding.ClientAuthType = clientAuthType
binding.ClientAuthType = int(clientAuthType)
isSet = true
}
@ -709,7 +786,7 @@ func getHTTPDBindingFromEnv(idx int) {
port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PORT", idx))
if ok {
binding.Port = port
binding.Port = int(port)
isSet = true
}
@ -733,7 +810,7 @@ func getHTTPDBindingFromEnv(idx int) {
clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx))
if ok {
binding.ClientAuthType = clientAuthType
binding.ClientAuthType = int(clientAuthType)
isSet = true
}
@ -790,6 +867,7 @@ func setViperDefaults() {
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
viper.SetDefault("common.defender.score_rate_exceeded", globalConf.Common.DefenderConfig.ScoreRateExceeded)
viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)
@ -900,12 +978,12 @@ func lookupBoolFromEnv(envName string) (bool, bool) {
return false, false
}
func lookupIntFromEnv(envName string) (int, bool) {
func lookupIntFromEnv(envName string) (int64, bool) {
value, ok := os.LookupEnv(envName)
if ok {
converted, err := strconv.ParseInt(value, 10, 16)
converted, err := strconv.ParseInt(value, 10, 64)
if err == nil {
return int(converted), ok
return converted, ok
}
}

View file

@ -18,6 +18,7 @@ import (
"github.com/drakkan/sftpgo/httpclient"
"github.com/drakkan/sftpgo/httpd"
"github.com/drakkan/sftpgo/sftpd"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/webdavd"
)
@ -427,6 +428,61 @@ func TestHTTPDBindingsCompatibility(t *testing.T) {
assert.NoError(t, err)
}
func TestRateLimitersFromEnv(t *testing.T) {
reset()
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE", "100")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD", "2000")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST", "10")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE", "2")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS", "SSH, FTP")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS", "1")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT", "50")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT", "100")
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE", "50")
t.Cleanup(func() {
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT")
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
limiters := config.GetCommonConfig().RateLimitersConfig
require.Len(t, limiters, 2)
require.Equal(t, int64(100), limiters[0].Average)
require.Equal(t, int64(2000), limiters[0].Period)
require.Equal(t, 10, limiters[0].Burst)
require.Equal(t, 2, limiters[0].Type)
protocols := limiters[0].Protocols
require.Len(t, protocols, 2)
require.True(t, utils.IsStringInSlice(common.ProtocolFTP, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolSSH, protocols))
require.True(t, limiters[0].GenerateDefenderEvents)
require.Equal(t, 50, limiters[0].EntriesSoftLimit)
require.Equal(t, 100, limiters[0].EntriesHardLimit)
require.Equal(t, int64(50), limiters[1].Average)
// we check the default values here
require.Equal(t, int64(1000), limiters[1].Period)
require.Equal(t, 1, limiters[1].Burst)
require.Equal(t, 2, limiters[1].Type)
protocols = limiters[1].Protocols
require.Len(t, protocols, 3)
require.True(t, utils.IsStringInSlice(common.ProtocolFTP, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolSSH, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolWebDAV, protocols))
require.False(t, limiters[1].GenerateDefenderEvents)
require.Equal(t, 100, limiters[1].EntriesSoftLimit)
require.Equal(t, 150, limiters[1].EntriesHardLimit)
}
func TestSFTPDBindingsFromEnv(t *testing.T) {
reset()
@ -435,14 +491,12 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_SFTPD__BINDINGS__0__APPLY_PROXY_CONFIG", "false")
os.Setenv("SFTPGO_SFTPD__BINDINGS__3__ADDRESS", "127.0.1.1")
os.Setenv("SFTPGO_SFTPD__BINDINGS__3__PORT", "2203")
os.Setenv("SFTPGO_SFTPD__BINDINGS__3__APPLY_PROXY_CONFIG", "1")
t.Cleanup(func() {
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS")
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__PORT")
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__APPLY_PROXY_CONFIG")
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__ADDRESS")
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__PORT")
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__APPLY_PROXY_CONFIG")
})
configDir := ".."
@ -455,7 +509,7 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
require.False(t, bindings[0].ApplyProxyConfig)
require.Equal(t, 2203, bindings[1].Port)
require.Equal(t, "127.0.1.1", bindings[1].Address)
require.True(t, bindings[1].ApplyProxyConfig)
require.True(t, bindings[1].ApplyProxyConfig) // default value
}
func TestFTPDBindingsFromEnv(t *testing.T) {
@ -469,7 +523,6 @@ func TestFTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_FTPD__BINDINGS__0__TLS_CIPHER_SUITES", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__ADDRESS", "127.0.1.1")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__PORT", "2203")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__APPLY_PROXY_CONFIG", "t")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__TLS_MODE", "1")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__FORCE_PASSIVE_IP", "127.0.1.1")
os.Setenv("SFTPGO_FTPD__BINDINGS__9__CLIENT_AUTH_TYPE", "2")
@ -483,7 +536,6 @@ func TestFTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__TLS_CIPHER_SUITES")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__ADDRESS")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__PORT")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__APPLY_PROXY_CONFIG")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__TLS_MODE")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__FORCE_PASSIVE_IP")
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CLIENT_AUTH_TYPE")
@ -505,7 +557,7 @@ func TestFTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[1])
require.Equal(t, 2203, bindings[1].Port)
require.Equal(t, "127.0.1.1", bindings[1].Address)
require.True(t, bindings[1].ApplyProxyConfig)
require.True(t, bindings[1].ApplyProxyConfig) // default value
require.Equal(t, 1, bindings[1].TLSMode)
require.Equal(t, "127.0.1.1", bindings[1].ForcePassiveIP)
require.Equal(t, 2, bindings[1].ClientAuthType)


@ -8,6 +8,7 @@ You can configure a score for each event type:
- `score_valid`, defines the score for valid login attempts, e.g. user accounts that exist. Default `1`.
- `score_invalid`, defines the score for invalid login attempts, e.g. non-existent user accounts or clients disconnected for inactivity without authentication attempts. Default `2`.
- `score_rate_exceeded`, defines the score for hosts that exceeded the configured rate limits. Default `3`.
And then you can configure:


@ -72,11 +72,21 @@ The configuration file contains the following sections:
- `threshold`, integer. Threshold value for banning a client.
- `score_invalid`, integer. Score for invalid login attempts, e.g. non-existent user accounts or clients disconnected for inactivity without authentication attempts.
- `score_valid`, integer. Score for valid login attempts, e.g. user accounts that exist.
- `score_rate_exceeded`, integer. Score for hosts that exceeded the configured rate limits.
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
- `entries_soft_limit`, integer.
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.
- `safelist_file`, string. Path to a file containing a list of IP addresses and/or networks to never ban.
- `blocklist_file`, string. Path to a file containing a list of IP addresses and/or networks to always ban. The lists can be reloaded on demand by sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. A host that is already banned will not be automatically unbanned if you add it to the safe list; you have to unban it using the REST API.
- `rate_limiters`, list of structs containing the rate limiters configuration. Take a look [here](./rate-limiting.md) for more details. Each struct has the following fields:
- `average`, integer. Average defines the maximum rate allowed. 0 means disabled. Default: 0.
- `period`, integer. Period defines the period in milliseconds. The rate is actually defined by dividing `average` by `period`. Default: 1000 (1 second).
- `burst`, integer. Burst defines the maximum number of requests allowed to go through in the same arbitrarily small period of time. Default: 1.
- `type`, integer. 1 means a global rate limiter, independent from the source host. 2 means a per-ip rate limiter. Default: 2.
- `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`. By default all supported protocols are enabled.
- `generate_defender_events`, boolean. If `true`, and the defender is enabled, and this is a per-source rate limiter, a new defender event will be generated each time the configured limit is exceeded. Default `false`.
- `entries_soft_limit`, integer.
- `entries_hard_limit`, integer. The number of per-ip rate limiters kept in memory will vary between the soft and hard limit.
- **"sftpd"**, the configuration for the SFTP server
- `bindings`, list of structs. Each struct has the following fields:
- `port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022

docs/rate-limiting.md (new file, 53 lines)

@ -0,0 +1,53 @@
# Rate limiting
Rate limiting allows you to control the number of requests going to the configured services.
SFTPGo implements a [token bucket](https://en.wikipedia.org/wiki/Token_bucket) that starts full and is refilled at the configured rate. The `burst` configuration parameter defines the size of the bucket. The rate is defined by dividing `average` by `period`, so for a rate below 1 req/s you need to define a period larger than a second.
Requests that exceed the configured limit are delayed or, if the wait required to respect the rate exceeds the maximum allowed delay, denied.
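For example, `average: 10` with `period: 1000` allows 10 requests per second, while `average: 1` with `period: 10000` allows one request every 10 seconds (0.1 req/s).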
SFTPGo allows you to define per-protocol rate limiters, so you can have different configurations for different protocols.
You can also define two types of rate limiters:
- global, independent from the source host and therefore defining a single limit for the configured protocol(s)
- per-host, which can be connected to the built-in [defender](./defender.md) to generate `score_rate_exceeded` events, so hosts that repeatedly exceed the configured limit can be automatically blocked
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service; you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
You can define as many rate limiters as you want, but keep in mind that if you define multiple rate limiters, each request is checked against all the configured limiters and can therefore be delayed multiple times. Let's clarify with an example; here is a configuration that defines a global rate limiter and a per-host rate limiter for the FTP protocol:
```json
"rate_limiters": [
{
"average": 100,
"period": 1000,
"burst": 1,
"type": 1,
"protocols": [
"SSH",
"FTP",
"DAV"
],
"generate_defender_events": false,
"entries_soft_limit": 100,
"entries_hard_limit": 150
},
{
"average": 10,
"period": 1000,
"burst": 1,
"type": 2,
"protocols": [
"FTP"
],
"generate_defender_events": true,
"entries_soft_limit": 100,
"entries_hard_limit": 150
}
]
```
Here we have a global rate limiter that limits the rate for the whole service to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host.
With this configuration, when a client connects via FTP it will be limited first by the global rate limiter and then by the per-host rate limiter.
Clients connecting via SFTP/WebDAV will be checked only against the global rate limiter.


@ -4,7 +4,7 @@ The `WebDAV` support can be enabled by configuring one or more `bindings` inside
Each user can access their home directory using the path `http/s://<SFTPGo ip>:<WebDAV port>/<prefix>`. By default `prefix` is empty. If you define a prefix it must be an absolute URI, for example `/dav`.
WebDAV is quite a different protocol than SCP/FTP: there is no session concept, each command is a separate HTTP request and must be authenticated. To improve performance, SFTPGo caches authenticated users, so it doesn't need to do a dataprovider query and a password check for each request.
WebDAV is quite a different protocol than SFTP/FTP: there is no session concept, each command is a separate HTTP request and must be authenticated. To improve performance, SFTPGo caches authenticated users, so it doesn't need to do a dataprovider query and a password check for each request.
The user caching configuration allows you to set:


@ -771,12 +771,71 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = oldValue
}
func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 5
cfg.DefenderConfig.ScoreRateExceeded = 3
cfg.RateLimitersConfig = []common.RateLimiterConfig{
{
Average: 1,
Period: 1000,
Burst: 1,
Type: 2,
Protocols: []string{common.ProtocolFTP},
GenerateDefenderEvents: true,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
},
}
err := common.Initialize(cfg)
assert.NoError(t, err)
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client, err := getFTPClient(user, false, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
err = client.Quit()
assert.NoError(t, err)
}
_, err = getFTPClient(user, true, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "rate limit exceed")
}
_, err = getFTPClient(user, false, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "rate limit exceed")
}
_, err = getFTPClient(user, true, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "banned client IP")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = common.Initialize(oldConfig)
assert.NoError(t, err)
}
func TestDefender(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)
@ -800,7 +859,7 @@ func TestDefender(t *testing.T) {
user.Password = defaultPassword
_, err = getFTPClient(user, false, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Access denied, banned client IP")
assert.Contains(t, err.Error(), "banned client IP")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)


@ -138,14 +138,17 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
if common.IsBanned(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
return "Access denied, banned client IP", common.ErrConnectionDenied
return "Access denied: banned client IP", common.ErrConnectionDenied
}
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
return "", common.ErrConnectionDenied
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
}
if err := common.LimitRate(common.ProtocolFTP, ipAddr); err != nil {
return fmt.Sprintf("Access denied: %v", err.Error()), err
}
if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil {
return "", err
return "Access denied by post connect hook", err
}
connID := fmt.Sprintf("%v_%v", s.ID, cc.ID())
user := dataprovider.User{}

go.mod (2 lines changed)

@ -68,7 +68,7 @@ require (
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57
golang.org/x/text v0.3.6 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
google.golang.org/api v0.44.0
google.golang.org/genproto v0.0.0-20210406143921-e86de6bf7a46 // indirect
gopkg.in/ini.v1 v1.62.0 // indirect


@ -2809,6 +2809,7 @@ func TestDefenderAPI(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
err := common.Initialize(cfg)
require.NoError(t, err)


@ -360,6 +360,9 @@ func canAcceptConnection(ip string) bool {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
return false
}
if err := common.LimitRate(common.ProtocolSSH, ip); err != nil {
return false
}
if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {
return false
}


@ -480,12 +480,51 @@ func TestLoginNonExistentUser(t *testing.T) {
assert.Error(t, err)
}
func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.RateLimitersConfig = []common.RateLimiterConfig{
{
Average: 1,
Period: 1000,
Burst: 1,
Type: 1,
Protocols: []string{common.ProtocolSSH},
},
}
err := common.Initialize(cfg)
assert.NoError(t, err)
usePubKey := false
user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
}
_, err = getSftpClient(user, usePubKey)
assert.Error(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = common.Initialize(oldConfig)
assert.NoError(t, err)
}
func TestDefender(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)


@ -19,12 +19,29 @@
"threshold": 15,
"score_invalid": 2,
"score_valid": 1,
"score_rate_exceeded": 3,
"observation_time": 30,
"entries_soft_limit": 100,
"entries_hard_limit": 150,
"safelist_file": "",
"blocklist_file": ""
}
},
"rate_limiters": [
{
"average": 0,
"period": 1000,
"burst": 1,
"type": 2,
"protocols": [
"SSH",
"FTP",
"DAV"
],
"generate_defender_events": false,
"entries_soft_limit": 100,
"entries_hard_limit": 150
}
]
},
"sftpd": {
"bindings": [


@ -158,6 +158,10 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return
}
if err := common.LimitRate(common.ProtocolWebDAV, ipAddr); err != nil {
http.Error(w, err.Error(), http.StatusTooManyRequests)
return
}
if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolWebDAV); err != nil {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return


@ -704,12 +704,49 @@ func TestLoginNonExistentUser(t *testing.T) {
assert.Error(t, checkBasicFunc(client))
}
func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.RateLimitersConfig = []common.RateLimiterConfig{
{
Average: 1,
Period: 1000,
Burst: 3,
Type: 1,
Protocols: []string{common.ProtocolWebDAV},
},
}
err := common.Initialize(cfg)
assert.NoError(t, err)
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(user, false, nil)
assert.NoError(t, checkBasicFunc(client))
_, err = client.ReadDir(".")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "429")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = common.Initialize(oldConfig)
assert.NoError(t, err)
}
func TestDefender(t *testing.T) {
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreRateExceeded = 2
err := common.Initialize(cfg)
assert.NoError(t, err)