Browse Source

add a global whitelist

if defined only the listed IPs/networks can access the configured
services, all other client connections will be dropped before they
even try to authenticate

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 years ago
parent
commit
d9f30e7ac5

+ 63 - 6
common/common.go

@@ -142,11 +142,12 @@ func Initialize(c Configuration, isShared int) error {
 	Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
 	startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
 	Config.defender = nil
+	Config.whitelist = nil
 	rateLimiters = make(map[string][]*rateLimiter)
 	for _, rlCfg := range c.RateLimitersConfig {
 		if rlCfg.isEnabled() {
 			if err := rlCfg.validate(); err != nil {
-				return fmt.Errorf("rate limiters initialization error: %v", err)
+				return fmt.Errorf("rate limiters initialization error: %w", err)
 			}
 			allowList, err := util.ParseAllowedIPAndRanges(rlCfg.AllowList)
 			if err != nil {
@@ -177,6 +178,16 @@ func Initialize(c Configuration, isShared int) error {
 		logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig)
 		Config.defender = defender
 	}
+	if c.WhiteListFile != "" {
+		whitelist := &whitelist{
+			fileName: c.WhiteListFile,
+		}
+		if err := whitelist.reload(); err != nil {
+			return fmt.Errorf("whitelist initialization error: %w", err)
+		}
+		logger.Info(logSender, "", "whitelist initialized from file: %#v", c.WhiteListFile)
+		Config.whitelist = whitelist
+	}
 	vfs.SetTempPath(c.TempPath)
 	dataprovider.SetTempPath(c.TempPath)
 	transfersChecker = getTransfersChecker(isShared)
@@ -197,13 +208,19 @@ func LimitRate(protocol, ip string) (time.Duration, error) {
 	return 0, nil
 }
 
-// ReloadDefender reloads the defender's block and safe lists
-func ReloadDefender() error {
+// Reload reloads the whitelist and the defender's block and safe lists
+func Reload() error {
+	var errWithelist error
+	if Config.whitelist != nil {
+		errWithelist = Config.whitelist.reload()
+	}
 	if Config.defender == nil {
-		return nil
+		return errWithelist
 	}
-
-	return Config.defender.Reload()
+	if err := Config.defender.Reload(); err != nil {
+		return err
+	}
+	return errWithelist
 }
 
 // IsBanned returns true if the specified IP address is banned
@@ -379,6 +396,35 @@ func (t *ConnectionTransfer) getConnectionTransferAsString() string {
 	return result
 }
 
+type whitelist struct {
+	fileName string
+	sync.RWMutex
+	list HostList
+}
+
+func (l *whitelist) reload() error {
+	list, err := loadHostListFromFile(l.fileName)
+	if err != nil {
+		return err
+	}
+	if list == nil {
+		return errors.New("cannot accept a nil whitelist")
+	}
+
+	l.Lock()
+	defer l.Unlock()
+
+	l.list = *list
+	return nil
+}
+
+func (l *whitelist) isAllowed(ip string) bool {
+	l.RLock()
+	defer l.RUnlock()
+
+	return l.list.isListed(ip)
+}
+
 // Configuration defines configuration parameters common to all supported protocols
 type Configuration struct {
 	// Maximum idle timeout as minutes. If a client is idle for a time that exceeds this setting it will be disconnected.
@@ -444,6 +490,10 @@ type Configuration struct {
 	MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
 	// Maximum number of concurrent client connections from the same host (IP). 0 means unlimited
 	MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"`
+	// Path to a file containing a list of IP addresses and/or networks to allow.
+	// Only the listed IPs/networks can access the configured services, all other client connections
+	// will be dropped before they even try to authenticate.
+	WhiteListFile string `json:"whitelist_file" mapstructure:"whitelist_file"`
 	// Defender configuration
 	DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
 	// Rate limiter configurations
@@ -451,6 +501,7 @@ type Configuration struct {
 	idleTimeoutAsDuration time.Duration
 	idleLoginTimeout      time.Duration
 	defender              Defender
+	whitelist             *whitelist
 }
 
 // IsAtomicUploadEnabled returns true if atomic upload is enabled
@@ -924,7 +975,13 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
 }
 
 // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
+// or a whitelist is defined and the specified ipAddr is not listed
 func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
+	if Config.whitelist != nil {
+		if !Config.whitelist.isAllowed(ipAddr) {
+			return false
+		}
+	}
 	if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
 		return true
 	}

+ 63 - 2
common/common_test.go

@@ -130,7 +130,7 @@ func TestDefenderIntegration(t *testing.T) {
 
 	ip := "127.1.1.1"
 
-	assert.Nil(t, ReloadDefender())
+	assert.Nil(t, Reload())
 
 	AddDefenderEvent(ip, HostEventNoLoginTried)
 	assert.False(t, IsBanned(ip))
@@ -173,9 +173,18 @@ func TestDefenderIntegration(t *testing.T) {
 	// ScoreInvalid cannot be greater than threshold
 	assert.Error(t, err)
 	Config.DefenderConfig.Threshold = 3
+	Config.DefenderConfig.SafeListFile = filepath.Join(os.TempDir(), "sl.json")
+	err = os.WriteFile(Config.DefenderConfig.SafeListFile, []byte(`{}`), 0644)
+	assert.NoError(t, err)
+	defer os.Remove(Config.DefenderConfig.SafeListFile)
+
 	err = Initialize(Config, 0)
 	assert.NoError(t, err)
-	assert.Nil(t, ReloadDefender())
+	assert.Nil(t, Reload())
+	err = os.WriteFile(Config.DefenderConfig.SafeListFile, []byte(`{`), 0644)
+	assert.NoError(t, err)
+	err = Reload()
+	assert.Error(t, err)
 
 	AddDefenderEvent(ip, HostEventNoLoginTried)
 	assert.False(t, IsBanned(ip))
@@ -291,6 +300,58 @@ func TestRateLimitersIntegration(t *testing.T) {
 	Config = configCopy
 }
 
+func TestWhitelist(t *testing.T) {
+	configCopy := Config
+
+	Config.whitelist = &whitelist{}
+	err := Config.whitelist.reload()
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "cannot accept a nil whitelist")
+	}
+	wlFile := filepath.Join(os.TempDir(), "wl.json")
+	Config.WhiteListFile = wlFile
+
+	err = os.WriteFile(wlFile, []byte(`invalid list file`), 0664)
+	assert.NoError(t, err)
+	err = Initialize(Config, 0)
+	assert.Error(t, err)
+
+	wl := HostListFile{
+		IPAddresses:  []string{"172.18.1.1", "172.18.1.2"},
+		CIDRNetworks: []string{"10.8.7.0/24"},
+	}
+	data, err := json.Marshal(wl)
+	assert.NoError(t, err)
+	err = os.WriteFile(wlFile, data, 0664)
+	assert.NoError(t, err)
+	defer os.Remove(wlFile)
+
+	err = Initialize(Config, 0)
+	assert.NoError(t, err)
+
+	assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.1"))
+	assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
+	assert.True(t, Connections.IsNewConnectionAllowed("10.8.7.3"))
+	assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.2"))
+
+	wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3")
+	wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24")
+	data, err = json.Marshal(wl)
+	assert.NoError(t, err)
+	err = os.WriteFile(wlFile, data, 0664)
+	assert.NoError(t, err)
+	assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
+
+	err = Reload()
+	assert.NoError(t, err)
+	assert.True(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
+	assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
+	assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.2"))
+	assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.12"))
+
+	Config = configCopy
+}
+
 func TestMaxConnections(t *testing.T) {
 	oldValue := Config.MaxTotalConnections
 	perHost := Config.MaxPerHostConnections

+ 2 - 0
config/config.go

@@ -169,6 +169,7 @@ func Init() {
 			DataRetentionHook:     "",
 			MaxTotalConnections:   0,
 			MaxPerHostConnections: 20,
+			WhiteListFile:         "",
 			DefenderConfig: common.DefenderConfig{
 				Enabled:            false,
 				Driver:             common.DefenderDriverMemory,
@@ -1470,6 +1471,7 @@ func setViperDefaults() {
 	viper.SetDefault("common.data_retention_hook", globalConf.Common.DataRetentionHook)
 	viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
 	viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
+	viper.SetDefault("common.whitelist_file", globalConf.Common.WhiteListFile)
 	viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
 	viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver)
 	viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)

+ 1 - 0
docs/full-configuration.md

@@ -73,6 +73,7 @@ The configuration file contains the following sections:
   - `data_retention_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Data retention hook](./data-retention-hook.md) for more details. Leave empty to disable
   - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited. Default: 0.
   - `max_per_host_connections`, integer.  Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
+  - `whitelist_file`, string. Path to a file containing a list of IP addresses and/or networks to allow. Only the listed IPs/networks can access the configured services, all other client connections will be dropped before they even try to authenticate. The whitelist must be a JSON file with the same structure documented for the [defenders's list](./defender.md). The whitelist can be reloaded on demand sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. Default: "".
   - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
     - `enabled`, boolean. Default `false`.
     - `driver`, string. Supported drivers are `memory` and `provider`. The `provider` driver will use the configured data provider to store defender events and it is supported for `MySQL`, `PostgreSQL` and `CockroachDB` data providers. Using the `provider` driver you can share the defender events among multiple SFTPGO instances. For a single instance the `memory` driver will be much faster. Default: `memory`.

+ 2 - 2
ftpd/server.go

@@ -151,8 +151,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
 		return "Access denied: banned client IP", common.ErrConnectionDenied
 	}
 	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
-		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
-		return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
+		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr))
+		return "Access denied", common.ErrConnectionDenied
 	}
 	_, err := common.LimitRate(common.ProtocolFTP, ipAddr)
 	if err != nil {

+ 55 - 1
httpd/httpd_test.go

@@ -8700,7 +8700,7 @@ func TestWebClientMaxConnections(t *testing.T) {
 	setJWTCookieForReq(req, webToken)
 	rr = executeRequest(req)
 	checkResponseCode(t, http.StatusForbidden, rr)
-	assert.Contains(t, rr.Body.String(), "configured connections limit reached")
+	assert.Contains(t, rr.Body.String(), "connection not allowed from your ip")
 
 	common.Connections.Remove(connection.GetID())
 	_, err = httpdtest.RemoveUser(user, http.StatusOK)
@@ -12617,6 +12617,60 @@ func TestWebAdminSetupMock(t *testing.T) {
 	os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1")
 }
 
+func TestWhitelist(t *testing.T) {
+	configCopy := common.Config
+
+	common.Config.MaxTotalConnections = 1
+	wlFile := filepath.Join(os.TempDir(), "wl.json")
+	common.Config.WhiteListFile = wlFile
+	wl := common.HostListFile{
+		IPAddresses:  []string{"172.120.1.1", "172.120.1.2"},
+		CIDRNetworks: []string{"192.8.7.0/22"},
+	}
+	data, err := json.Marshal(wl)
+	assert.NoError(t, err)
+	err = os.WriteFile(wlFile, data, 0664)
+	assert.NoError(t, err)
+	defer os.Remove(wlFile)
+
+	err = common.Initialize(common.Config, 0)
+	assert.NoError(t, err)
+
+	req, _ := http.NewRequest(http.MethodGet, webLoginPath, nil)
+	rr := executeRequest(req)
+	checkResponseCode(t, http.StatusForbidden, rr)
+	assert.Contains(t, rr.Body.String(), "connection not allowed from your ip")
+
+	req.RemoteAddr = "172.120.1.1"
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	req.RemoteAddr = "172.120.1.3"
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusForbidden, rr)
+	assert.Contains(t, rr.Body.String(), "connection not allowed from your ip")
+
+	req.RemoteAddr = "192.8.7.1"
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	wl.IPAddresses = append(wl.IPAddresses, "172.120.1.3")
+	data, err = json.Marshal(wl)
+	assert.NoError(t, err)
+	err = os.WriteFile(wlFile, data, 0664)
+	assert.NoError(t, err)
+	err = common.Reload()
+	assert.NoError(t, err)
+
+	req.RemoteAddr = "172.120.1.3"
+	rr = executeRequest(req)
+	checkResponseCode(t, http.StatusOK, rr)
+
+	common.Config = configCopy
+	err = common.Initialize(common.Config, 0)
+	assert.NoError(t, err)
+}
+
 func TestWebAdminLoginMock(t *testing.T) {
 	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
 	assert.NoError(t, err)

+ 2 - 2
httpd/server.go

@@ -983,8 +983,8 @@ func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
 		defer common.Connections.RemoveClientConnection(ipAddr)
 
 		if !common.Connections.IsNewConnectionAllowed(ipAddr) {
-			logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection refused, configured limit reached")
-			s.sendForbiddenResponse(w, r, "configured connections limit reached")
+			logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr))
+			s.sendForbiddenResponse(w, r, "connection not allowed from your ip")
 			return
 		}
 		if common.IsBanned(ipAddr) {

+ 2 - 2
service/service_windows.go

@@ -137,9 +137,9 @@ loop:
 			if err != nil {
 				logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err)
 			}
-			err = common.ReloadDefender()
+			err = common.Reload()
 			if err != nil {
-				logger.Warn(logSender, "", "error reloading defender's lists: %v", err)
+				logger.Warn(logSender, "", "error reloading common configs: %v", err)
 			}
 		case rotateLogCmd:
 			logger.Debug(logSender, "", "Received log file rotation request")

+ 2 - 2
service/signals_unix.go

@@ -57,9 +57,9 @@ func handleSIGHUP() {
 	if err != nil {
 		logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err)
 	}
-	err = common.ReloadDefender()
+	err = common.Reload()
 	if err != nil {
-		logger.Warn(logSender, "", "error reloading defender's lists: %v", err)
+		logger.Warn(logSender, "", "error reloading common configs: %v", err)
 	}
 }
 

+ 1 - 1
sftpd/server.go

@@ -411,7 +411,7 @@ func canAcceptConnection(ip string) bool {
 		return false
 	}
 	if !common.Connections.IsNewConnectionAllowed(ip) {
-		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
+		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", fmt.Sprintf("connection not allowed from ip %#v", ip))
 		return false
 	}
 	_, err := common.LimitRate(common.ProtocolSSH, ip)

+ 1 - 0
sftpgo.json

@@ -17,6 +17,7 @@
     "data_retention_hook": "",
     "max_total_connections": 0,
     "max_per_host_connections": 20,
+    "whitelist_file": "",
     "defender": {
       "enabled": false,
       "driver": "memory",

+ 1 - 1
webdavd/server.go

@@ -146,7 +146,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	defer common.Connections.RemoveClientConnection(ipAddr)
 
 	if !common.Connections.IsNewConnectionAllowed(ipAddr) {
-		logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection refused, configured limit reached")
+		logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", fmt.Sprintf("connection not allowed from ip %#v", ipAddr))
 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
 		return
 	}