From d9f30e7ac57f56260ede2ecd192c4ecd708f680d Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 17 Mar 2022 22:10:52 +0100 Subject: [PATCH] add a global whitelist if defined only the listed IPs/networks can access the configured services, all other client connections will be dropped before they even try to authenticate Signed-off-by: Nicola Murino --- common/common.go | 71 ++++++++++++++++++++++++++++++++++---- common/common_test.go | 65 ++++++++++++++++++++++++++++++++-- config/config.go | 2 ++ docs/full-configuration.md | 1 + ftpd/server.go | 4 +-- httpd/httpd_test.go | 56 +++++++++++++++++++++++++++++- httpd/server.go | 4 +-- service/service_windows.go | 4 +-- service/signals_unix.go | 4 +-- sftpd/server.go | 2 +- sftpgo.json | 1 + webdavd/server.go | 2 +- 12 files changed, 196 insertions(+), 20 deletions(-) diff --git a/common/common.go b/common/common.go index f5f61222..c222eb9c 100644 --- a/common/common.go +++ b/common/common.go @@ -142,11 +142,12 @@ func Initialize(c Configuration, isShared int) error { Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute startPeriodicTimeoutTicker(periodicTimeoutCheckInterval) Config.defender = nil + Config.whitelist = nil rateLimiters = make(map[string][]*rateLimiter) for _, rlCfg := range c.RateLimitersConfig { if rlCfg.isEnabled() { if err := rlCfg.validate(); err != nil { - return fmt.Errorf("rate limiters initialization error: %v", err) + return fmt.Errorf("rate limiters initialization error: %w", err) } allowList, err := util.ParseAllowedIPAndRanges(rlCfg.AllowList) if err != nil { @@ -177,6 +178,16 @@ func Initialize(c Configuration, isShared int) error { logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig) Config.defender = defender } + if c.WhiteListFile != "" { + whitelist := &whitelist{ + fileName: c.WhiteListFile, + } + if err := whitelist.reload(); err != nil { + return fmt.Errorf("whitelist initialization error: %w", err) + } + logger.Info(logSender, "", "whitelist initialized from file: %#v", c.WhiteListFile) + Config.whitelist = whitelist + } vfs.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath) transfersChecker = getTransfersChecker(isShared) @@ -197,13 +208,19 @@ func LimitRate(protocol, ip string) (time.Duration, error) { return 0, nil } -// ReloadDefender reloads the defender's block and safe lists -func ReloadDefender() error { - if Config.defender == nil { - return nil +// Reload reloads the whitelist and the defender's block and safe lists +func Reload() error { + var errWithelist error + if Config.whitelist != nil { + errWithelist = Config.whitelist.reload() } - - return Config.defender.Reload() + if Config.defender == nil { + return errWithelist + } + if err := Config.defender.Reload(); err != nil { + return err + } + return errWithelist } // IsBanned returns true if the specified IP address is banned @@ -379,6 +396,35 @@ func (t *ConnectionTransfer) getConnectionTransferAsString() string { return result } +type whitelist struct { + fileName string + sync.RWMutex + list HostList +} + +func (l *whitelist) reload() error { + list, err := loadHostListFromFile(l.fileName) + if err != nil { + return err + } + if list == nil { + return errors.New("cannot accept a nil whitelist") + } + + l.Lock() + defer l.Unlock() + + l.list = *list + return nil +} + +func (l *whitelist) isAllowed(ip string) bool { + l.RLock() + defer l.RUnlock() + + return l.list.isListed(ip) +} + // Configuration defines configuration parameters common to all supported protocols type Configuration struct { // Maximum idle timeout as minutes. If a client is idle for a time that exceeds this setting it will be disconnected. @@ -444,6 +490,10 @@ type Configuration struct { MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"` // Maximum number of concurrent client connections from the same host (IP). 0 means unlimited MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"` + // Path to a file containing a list of IP addresses and/or networks to allow. + // Only the listed IPs/networks can access the configured services, all other client connections + // will be dropped before they even try to authenticate. + WhiteListFile string `json:"whitelist_file" mapstructure:"whitelist_file"` // Defender configuration DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"` // Rate limiter configurations @@ -451,6 +501,7 @@ type Configuration struct { idleTimeoutAsDuration time.Duration idleLoginTimeout time.Duration defender Defender + whitelist *whitelist } // IsAtomicUploadEnabled returns true if atomic upload is enabled @@ -924,7 +975,13 @@ func (conns *ActiveConnections) GetClientConnections() int32 { } // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded +// or a whitelist is defined and the specified ipAddr is not listed func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool { + if Config.whitelist != nil { + if !Config.whitelist.isAllowed(ipAddr) { + return false + } + } if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { return true } diff --git a/common/common_test.go b/common/common_test.go index 65626d77..99513ed0 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -130,7 +130,7 @@ func TestDefenderIntegration(t *testing.T) { ip := "127.1.1.1" - assert.Nil(t, ReloadDefender()) + assert.Nil(t, Reload()) AddDefenderEvent(ip, HostEventNoLoginTried) assert.False(t, IsBanned(ip)) @@ -173,9 +173,18 @@ func TestDefenderIntegration(t *testing.T) { // ScoreInvalid cannot be greater than threshold assert.Error(t, err) Config.DefenderConfig.Threshold = 3 + Config.DefenderConfig.SafeListFile = filepath.Join(os.TempDir(), "sl.json") + err = os.WriteFile(Config.DefenderConfig.SafeListFile, []byte(`{}`), 0644) + assert.NoError(t, err) + defer os.Remove(Config.DefenderConfig.SafeListFile) + err = Initialize(Config, 0) assert.NoError(t, err) - assert.Nil(t, ReloadDefender()) + assert.Nil(t, Reload()) + err = os.WriteFile(Config.DefenderConfig.SafeListFile, []byte(`{`), 0644) + assert.NoError(t, err) + err = Reload() + assert.Error(t, err) AddDefenderEvent(ip, HostEventNoLoginTried) assert.False(t, IsBanned(ip)) @@ -291,6 +300,58 @@ func TestRateLimitersIntegration(t *testing.T) { Config = configCopy } +func TestWhitelist(t *testing.T) { + configCopy := Config + + Config.whitelist = &whitelist{} + err := Config.whitelist.reload() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "cannot accept a nil whitelist") + } + wlFile := filepath.Join(os.TempDir(), "wl.json") + Config.WhiteListFile = wlFile + + err = os.WriteFile(wlFile, []byte(`invalid list file`), 0664) + assert.NoError(t, err) + err = Initialize(Config, 0) + assert.Error(t, err) + + wl := HostListFile{ + IPAddresses: []string{"172.18.1.1", "172.18.1.2"}, + CIDRNetworks: []string{"10.8.7.0/24"}, + } + data, err := json.Marshal(wl) + assert.NoError(t, err) + err = os.WriteFile(wlFile, data, 0664) + assert.NoError(t, err) + defer os.Remove(wlFile) + + err = Initialize(Config, 0) + assert.NoError(t, err) + + assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.1")) + assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.3")) + assert.True(t, Connections.IsNewConnectionAllowed("10.8.7.3")) + assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.2")) + + wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3") + wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24") + data, err = json.Marshal(wl) + assert.NoError(t, err) + err = os.WriteFile(wlFile, data, 0664) + assert.NoError(t, err) + assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.3")) + + err = Reload() + assert.NoError(t, err) + assert.True(t, Connections.IsNewConnectionAllowed("10.8.8.3")) + assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.3")) + assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.2")) + assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.12")) + + Config = configCopy +} + func TestMaxConnections(t *testing.T) { oldValue := Config.MaxTotalConnections perHost := Config.MaxPerHostConnections diff --git a/config/config.go b/config/config.go index 446f5799..07241346 100644 --- a/config/config.go +++ b/config/config.go @@ -169,6 +169,7 @@ func Init() { DataRetentionHook: "", MaxTotalConnections: 0, MaxPerHostConnections: 20, + WhiteListFile: "", DefenderConfig: common.DefenderConfig{ Enabled: false, Driver: common.DefenderDriverMemory, @@ -1470,6 +1471,7 @@ func setViperDefaults() { viper.SetDefault("common.data_retention_hook", globalConf.Common.DataRetentionHook) viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections) + viper.SetDefault("common.whitelist_file", globalConf.Common.WhiteListFile) viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled) viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver) viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime) diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 4128395b..0c1b972b 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -73,6 +73,7 @@ The configuration file contains the following sections: - `data_retention_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Data retention hook](./data-retention-hook.md) for more details. Leave empty to disable - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0. - `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20. + - `whitelist_file`, string. Path to a file containing a list of IP addresses and/or networks to allow. Only the listed IPs/networks can access the configured services, all other client connections will be dropped before they even try to authenticate. The whitelist must be a JSON file with the same structure documented for the [defenders's list](./defender.md). The whitelist can be reloaded on demand sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. Default: "". - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details. - `enabled`, boolean. Default `false`. - `driver`, string. Supported drivers are `memory` and `provider`. The `provider` driver will use the configured data provider to store defender events and it is supported for `MySQL`, `PostgreSQL` and `CockroachDB` data providers. Using the `provider` driver you can share the defender events among multiple SFTPGO instances. For a single instance the `memory` driver will be much faster. Default: `memory`. diff --git a/ftpd/server.go b/ftpd/server.go index d2a03958..4420a87a 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -151,8 +151,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { return "Access denied: banned client IP", common.ErrConnectionDenied } if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached") - return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) + return "Access denied", common.ErrConnectionDenied } _, err := common.LimitRate(common.ProtocolFTP, ipAddr) if err != nil { diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 2bbc164b..e254dbf8 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -8700,7 +8700,7 @@ func TestWebClientMaxConnections(t *testing.T) { setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "configured connections limit reached") + assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") common.Connections.Remove(connection.GetID()) _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -12617,6 +12617,60 @@ func TestWebAdminSetupMock(t *testing.T) { os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") } +func TestWhitelist(t *testing.T) { + configCopy := common.Config + + common.Config.MaxTotalConnections = 1 + wlFile := filepath.Join(os.TempDir(), "wl.json") + common.Config.WhiteListFile = wlFile + wl := common.HostListFile{ + IPAddresses: []string{"172.120.1.1", "172.120.1.2"}, + CIDRNetworks: []string{"192.8.7.0/22"}, + } + data, err := json.Marshal(wl) + assert.NoError(t, err) + err = os.WriteFile(wlFile, data, 0664) + assert.NoError(t, err) + defer os.Remove(wlFile) + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, webLoginPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") + + req.RemoteAddr = "172.120.1.1" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req.RemoteAddr = "172.120.1.3" + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "connection not allowed from your ip") + + req.RemoteAddr = "192.8.7.1" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + wl.IPAddresses = append(wl.IPAddresses, "172.120.1.3") + data, err = json.Marshal(wl) + assert.NoError(t, err) + err = os.WriteFile(wlFile, data, 0664) + assert.NoError(t, err) + err = common.Reload() + assert.NoError(t, err) + + req.RemoteAddr = "172.120.1.3" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + common.Config = configCopy + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) +} + func TestWebAdminLoginMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) diff --git a/httpd/server.go b/httpd/server.go index 8a8e2b62..7e41382d 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -983,8 +983,8 @@ func (s *httpdServer) checkConnection(next http.Handler) http.Handler { defer common.Connections.RemoveClientConnection(ipAddr) if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached") - s.sendForbiddenResponse(w, r, "configured connections limit reached") + logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) + s.sendForbiddenResponse(w, r, "connection not allowed from your ip") return } if common.IsBanned(ipAddr) { diff --git a/service/service_windows.go b/service/service_windows.go index 7b64c0a1..8e717fee 100644 --- a/service/service_windows.go +++ b/service/service_windows.go @@ -137,9 +137,9 @@ loop: if err != nil { logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err) } - err = common.ReloadDefender() + err = common.Reload() if err != nil { - logger.Warn(logSender, "", "error reloading defender's lists: %v", err) + logger.Warn(logSender, "", "error reloading common configs: %v", err) } case rotateLogCmd: logger.Debug(logSender, "", "Received log file rotation request") diff --git a/service/signals_unix.go b/service/signals_unix.go index 5eb7c8a4..d8a4b4c8 100644 --- a/service/signals_unix.go +++ b/service/signals_unix.go @@ -57,9 +57,9 @@ func handleSIGHUP() { if err != nil { logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err) } - err = common.ReloadDefender() + err = common.Reload() if err != nil { - logger.Warn(logSender, "", "error reloading defender's lists: %v", err) + logger.Warn(logSender, "", "error reloading common configs: %v", err) } } diff --git a/sftpd/server.go b/sftpd/server.go index 45454d83..eecf2b6c 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -411,7 +411,7 @@ func canAcceptConnection(ip string) bool { return false } if !common.Connections.IsNewConnectionAllowed(ip) { - logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached") + logger.Log(logger.LevelDebug, common.ProtocolSSH, "", fmt.Sprintf("connection not allowed from ip %#v", ip)) return false } _, err := common.LimitRate(common.ProtocolSSH, ip) diff --git a/sftpgo.json b/sftpgo.json index dec1fb6e..c168460a 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -17,6 +17,7 @@ "data_retention_hook": "", "max_total_connections": 0, "max_per_host_connections": 20, + "whitelist_file": "", "defender": { "enabled": false, "driver": "memory", diff --git a/webdavd/server.go b/webdavd/server.go index 01f8eacc..b594280b 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -146,7 +146,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer common.Connections.RemoveClientConnection(ipAddr) if !common.Connections.IsNewConnectionAllowed(ipAddr) { - logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached") + logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr)) http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable) return }