add rate limiting support for REST API/web admin too

This commit is contained in:
Nicola Murino 2021-04-19 08:14:04 +02:00
parent 112e3b2fc2
commit f45c89fc46
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
15 changed files with 109 additions and 39 deletions

View file

@ -70,6 +70,7 @@ const (
ProtocolSSH = "SSH"
ProtocolFTP = "FTP"
ProtocolWebDAV = "DAV"
ProtocolHTTP = "HTTP"
)
// Upload modes
@ -144,14 +145,14 @@ func Initialize(c Configuration) error {
// allow one event to happen.
// It returns an error if the time to wait exceeds the max
// allowed delay
func LimitRate(protocol, ip string) error {
func LimitRate(protocol, ip string) (time.Duration, error) {
for _, limiter := range rateLimiters[protocol] {
if err := limiter.Wait(ip); err != nil {
if delay, err := limiter.Wait(ip); err != nil {
logger.Debug(logSender, "", "protocol %v ip %v: %v", protocol, ip, err)
return err
return delay, err
}
}
return nil
return 0, nil
}
// ReloadDefender reloads the defender's block and safe lists

View file

@ -194,30 +194,31 @@ func TestRateLimitersIntegration(t *testing.T) {
err = Initialize(Config)
assert.NoError(t, err)
assert.Len(t, rateLimiters, 3)
assert.Len(t, rateLimiters, 4)
assert.Len(t, rateLimiters[ProtocolSSH], 1)
assert.Len(t, rateLimiters[ProtocolFTP], 2)
assert.Len(t, rateLimiters[ProtocolWebDAV], 2)
assert.Len(t, rateLimiters[ProtocolHTTP], 1)
source1 := "127.1.1.1"
source2 := "127.1.1.2"
err = LimitRate(ProtocolSSH, source1)
_, err = LimitRate(ProtocolSSH, source1)
assert.NoError(t, err)
err = LimitRate(ProtocolFTP, source1)
_, err = LimitRate(ProtocolFTP, source1)
assert.NoError(t, err)
// sleep to allow the add configured burst to the token.
// This sleep is not enough to add the per-source burst
time.Sleep(20 * time.Millisecond)
err = LimitRate(ProtocolWebDAV, source2)
_, err = LimitRate(ProtocolWebDAV, source2)
assert.NoError(t, err)
err = LimitRate(ProtocolFTP, source1)
_, err = LimitRate(ProtocolFTP, source1)
assert.Error(t, err)
err = LimitRate(ProtocolWebDAV, source2)
_, err = LimitRate(ProtocolWebDAV, source2)
assert.Error(t, err)
err = LimitRate(ProtocolSSH, source1)
_, err = LimitRate(ProtocolSSH, source1)
assert.NoError(t, err)
err = LimitRate(ProtocolSSH, source2)
_, err = LimitRate(ProtocolSSH, source2)
assert.NoError(t, err)
Config = configCopy

View file

@ -16,7 +16,7 @@ import (
var (
errNoBucket = errors.New("no bucket found")
errReserve = errors.New("unable to reserve token")
rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV}
rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
)
// RateLimiterType defines the supported rate limiters types
@ -130,7 +130,7 @@ type rateLimiter struct {
// Wait blocks until the limit allows one event to happen
// or returns an error if the time to wait exceeds the max
// allowed delay
func (rl *rateLimiter) Wait(source string) error {
func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
var res *rate.Reservation
if rl.globalBucket != nil {
res = rl.globalBucket.Reserve()
@ -143,7 +143,7 @@ func (rl *rateLimiter) Wait(source string) error {
}
}
if !res.OK() {
return errReserve
return 0, errReserve
}
delay := res.Delay()
if delay > rl.maxDelay {
@ -151,10 +151,10 @@ func (rl *rateLimiter) Wait(source string) error {
if rl.generateDefenderEvents && rl.globalBucket == nil {
AddDefenderEvent(source, HostEventRateExceeded)
}
return fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
}
time.Sleep(delay)
return nil
return 0, nil
}
type sourceRateLimiter struct {

View file

@ -63,9 +63,9 @@ func TestRateLimiter(t *testing.T) {
Protocols: rateLimiterProtocolValues,
}
limiter := config.getLimiter()
err := limiter.Wait("")
_, err := limiter.Wait("")
require.NoError(t, err)
err = limiter.Wait("")
_, err = limiter.Wait("")
require.Error(t, err)
config.Type = int(rateLimiterTypeSource)
@ -75,17 +75,17 @@ func TestRateLimiter(t *testing.T) {
limiter = config.getLimiter()
source := "192.168.1.2"
err = limiter.Wait(source)
_, err = limiter.Wait(source)
require.NoError(t, err)
err = limiter.Wait(source)
_, err = limiter.Wait(source)
require.Error(t, err)
// a different source should work
err = limiter.Wait(source + "1")
_, err = limiter.Wait(source + "1")
require.NoError(t, err)
config.Burst = 0
limiter = config.getLimiter()
err = limiter.Wait(source)
_, err = limiter.Wait(source)
require.ErrorIs(t, err, errReserve)
}
@ -104,10 +104,10 @@ func TestLimiterCleanup(t *testing.T) {
source2 := "10.8.0.2"
source3 := "10.8.0.3"
source4 := "10.8.0.4"
err := limiter.Wait(source1)
_, err := limiter.Wait(source1)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
err = limiter.Wait(source2)
_, err = limiter.Wait(source2)
assert.NoError(t, err)
time.Sleep(20 * time.Millisecond)
assert.Len(t, limiter.buckets.buckets, 2)
@ -115,7 +115,7 @@ func TestLimiterCleanup(t *testing.T) {
assert.True(t, ok)
_, ok = limiter.buckets.buckets[source2]
assert.True(t, ok)
err = limiter.Wait(source3)
_, err = limiter.Wait(source3)
assert.NoError(t, err)
assert.Len(t, limiter.buckets.buckets, 3)
_, ok = limiter.buckets.buckets[source1]
@ -125,7 +125,7 @@ func TestLimiterCleanup(t *testing.T) {
_, ok = limiter.buckets.buckets[source3]
assert.True(t, ok)
time.Sleep(20 * time.Millisecond)
err = limiter.Wait(source4)
_, err = limiter.Wait(source4)
assert.NoError(t, err)
assert.Len(t, limiter.buckets.buckets, 2)
_, ok = limiter.buckets.buckets[source3]

View file

@ -74,7 +74,7 @@ var (
Period: 1000,
Burst: 1,
Type: 2,
Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV},
Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV, common.ProtocolHTTP},
GenerateDefenderEvents: false,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,

View file

@ -474,10 +474,11 @@ func TestRateLimitersFromEnv(t *testing.T) {
require.Equal(t, 1, limiters[1].Burst)
require.Equal(t, 2, limiters[1].Type)
protocols = limiters[1].Protocols
require.Len(t, protocols, 3)
require.Len(t, protocols, 4)
require.True(t, utils.IsStringInSlice(common.ProtocolFTP, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolSSH, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolWebDAV, protocols))
require.True(t, utils.IsStringInSlice(common.ProtocolHTTP, protocols))
require.False(t, limiters[1].GenerateDefenderEvents)
require.Equal(t, 100, limiters[1].EntriesSoftLimit)
require.Equal(t, 150, limiters[1].EntriesHardLimit)

View file

@ -83,7 +83,7 @@ The configuration file contains the following sections:
- `period`, integer. Period defines the period in milliseconds. The rate is actually defined by dividing average by period Default: 1000 (1 second).
- `burst`, integer. Burst defines the maximum number of requests allowed to go through in the same arbitrarily small period of time. Default: 1
- `type`, integer. 1 means a global rate limiter, independent from the source host. 2 means a per-ip rate limiter. Default: 2
- `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`. By default all supported protocols are enabled
- `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`, `HTTP`. By default all supported protocols are enabled
- `generate_defender_events`, boolean. If `true`, the defender is enabled, and this is not a global rate limiter, a new defender event will be generated each time the configured limit is exceeded. Default `false`
- `entries_soft_limit`, integer.
- `entries_hard_limit`, integer. The number of per-ip rate limiters kept in memory will vary between the soft and hard limit

View file

@ -1,6 +1,6 @@
# Rate limiting
Rate limiting allows to control the number of requests going to the configured services.
Rate limiting allows you to control the number of requests going to the SFTPGo services.
SFTPGo implements a [token bucket](https://en.wikipedia.org/wiki/Token_bucket) initially full and refilled at the configured rate. The `burst` configuration parameter defines the size of the bucket. The rate is defined by dividing `average` by `period`, so for a rate below 1 req/s, one needs to define a period larger than a second.
@ -8,9 +8,16 @@ Requests that exceed the configured limit will be delayed or denied if they exce
SFTPGo allows to define per-protocol rate limiters so you can have different configurations for different protocols.
The supported protocols are:
- `SSH`, includes SFTP and SSH commands
- `FTP`, includes FTP, FTPES, FTPS
- `DAV`, WebDAV
- `HTTP`, REST API and web admin
You can also define two types of rate limiters:
- global, it is independent from the source host and therefore define a limit for the configured protocol/s
- global, it is independent of the source host and therefore defines an aggregate limit for the configured protocol/s
- per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
@ -27,7 +34,8 @@ You can define as many rate limiters as you want, but keep in mind that if you
"protocols": [
"SSH",
"FTP",
"DAV"
"DAV",
"HTTP"
],
"generate_defender_events": false,
"entries_soft_limit": 100,
@ -48,6 +56,6 @@ You can defines how many rate limiters as you want, but keep in mind that if you
]
```
we have a global rate limiter that limit the rate for the whole service to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host.
we have a global rate limiter that limits the aggregate rate for all the services to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host.
With this configuration, when a client connects via FTP it will be limited first by the global rate limiter and then by the per host rate limiter.
Clients connecting via SFTP/WebDAV will be checked only against the global rate limiter.

View file

@ -144,7 +144,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
}
if err := common.LimitRate(common.ProtocolFTP, ipAddr); err != nil {
_, err := common.LimitRate(common.ProtocolFTP, ipAddr)
if err != nil {
return fmt.Sprintf("Access denied: %v", err.Error()), err
}
if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil {

View file

@ -3118,6 +3118,44 @@ func TestLoaddataMode(t *testing.T) {
assert.NoError(t, err)
}
// TestRateLimiter verifies end-to-end HTTP rate limiting: with a global
// (Type: 1) limiter of 1 req/s and burst 1 for the HTTP protocol, the
// first request to the healthz endpoint succeeds and the immediately
// following one is rejected with 429 and the Retry-After/X-Retry-In
// headers set. The original common configuration is restored at the end.
func TestRateLimiter(t *testing.T) {
// keep a copy of the current config so it can be restored below
oldConfig := config.GetCommonConfig()
cfg := config.GetCommonConfig()
cfg.RateLimitersConfig = []common.RateLimiterConfig{
{
Average: 1,
Period: 1000,
Burst: 1,
Type: 1,
Protocols: []string{common.ProtocolHTTP},
},
}
err := common.Initialize(cfg)
assert.NoError(t, err)
client := &http.Client{
Timeout: 5 * time.Second,
}
// first request consumes the single burst token and must succeed
resp, err := client.Get(httpBaseURL + healthzPath)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
err = resp.Body.Close()
assert.NoError(t, err)
// second request exceeds the limit: expect 429 plus the retry headers
// set by the rateLimiter middleware
resp, err = client.Get(httpBaseURL + healthzPath)
assert.NoError(t, err)
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
assert.NotEmpty(t, resp.Header.Get("Retry-After"))
assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
err = resp.Body.Close()
assert.NoError(t, err)
// restore the previous common configuration
err = common.Initialize(oldConfig)
assert.NoError(t, err)
}
func TestHTTPSConnection(t *testing.T) {
client := &http.Client{
Timeout: 5 * time.Second,

View file

@ -3,11 +3,13 @@ package httpd
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwt"
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
)
@ -141,3 +143,15 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
// rateLimiter is an HTTP middleware that applies the configured rate
// limits for the HTTP protocol, keyed on the client IP taken from the
// request remote address. If the limit is exceeded it sets the
// Retry-After (whole seconds) and X-Retry-In (Go duration string)
// headers from the delay returned by LimitRate and replies with
// 429 Too Many Requests; otherwise the request proceeds unchanged.
func rateLimiter(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
w.Header().Set("X-Retry-In", delay.String())
sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -259,6 +259,8 @@ func (s *httpdServer) initializeRouter() {
s.router.Use(saveConnectionAddress)
s.router.Use(middleware.GetHead)
s.router.Use(middleware.StripSlashes)
s.router.Use(middleware.RealIP)
s.router.Use(rateLimiter)
s.router.Group(func(r chi.Router) {
r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
@ -268,7 +270,6 @@ func (s *httpdServer) initializeRouter() {
s.router.Group(func(router chi.Router) {
router.Use(middleware.RequestID)
router.Use(middleware.RealIP)
router.Use(logger.NewStructuredLogger(logger.GetLogger()))
router.Use(middleware.Recoverer)

View file

@ -360,7 +360,8 @@ func canAcceptConnection(ip string) bool {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
return false
}
if err := common.LimitRate(common.ProtocolSSH, ip); err != nil {
_, err := common.LimitRate(common.ProtocolSSH, ip)
if err != nil {
return false
}
if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {

View file

@ -35,7 +35,8 @@
"protocols": [
"SSH",
"FTP",
"DAV"
"DAV",
"HTTP"
],
"generate_defender_events": false,
"entries_soft_limit": 100,

View file

@ -158,7 +158,10 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return
}
if err := common.LimitRate(common.ProtocolWebDAV, ipAddr); err != nil {
delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr)
if err != nil {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
w.Header().Set("X-Retry-In", delay.String())
http.Error(w, err.Error(), http.StatusTooManyRequests)
return
}