Browse Source

defender: allow to set a different score for "no auth tried" events

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 years ago
parent
commit
c0fe08b597

+ 5 - 2
docs/defender.md

@@ -2,14 +2,17 @@
 
 
 The built-in `defender` allows you to configure an auto-blocking policy for SFTPGo and thus helps to prevent DoS (Denial of Service) and brute force password guessing.
 The built-in `defender` allows you to configure an auto-blocking policy for SFTPGo and thus helps to prevent DoS (Denial of Service) and brute force password guessing.
 
 
-If enabled it will protect SFTP, HTTP, FTP and WebDAV services and it will automatically block hosts (IP addresses) that continually fail to log in or attempt to connect.
+If enabled it will protect SFTP, HTTP (WebClient and user API), FTP and WebDAV services and it will automatically block hosts (IP addresses) that continually fail to log in or attempt to connect.
 
 
 You can configure a score for the following events:
 You can configure a score for the following events:
 
 
 - `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
 - `score_valid`, defines the score for valid login attempts, eg. user accounts that exist. Default `1`.
-- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts. Default `2`.
+- `score_invalid`, defines the score for invalid login attempts, eg. non-existent user accounts. Default `2`.
+- `score_no_auth`, defines the score for clients disconnected without any authentication attempt. Default `0`.
 - `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
 - `score_limit_exceeded`, defines the score for hosts that exceeded the configured rate limits or the configured max connections per host. Default `3`.
 
 
+You can set the score to `0` to not penalize some events.
+
 And then you can configure:
 And then you can configure:
 
 
 - `observation_time`, defines the time window, in minutes, for tracking client errors.
 - `observation_time`, defines the time window, in minutes, for tracking client errors.

+ 10 - 9
docs/full-configuration.md

@@ -86,15 +86,16 @@ The configuration file contains the following sections:
   - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
   - `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
     - `enabled`, boolean. Default `false`.
     - `enabled`, boolean. Default `false`.
     - `driver`, string. Supported drivers are `memory` and `provider`. The `provider` driver will use the configured data provider to store defender events and it is supported for `MySQL`, `PostgreSQL` and `CockroachDB` data providers. Using the `provider` driver you can share the defender events among multiple SFTPGO instances. For a single instance the `memory` driver will be much faster. Default: `memory`.
     - `driver`, string. Supported drivers are `memory` and `provider`. The `provider` driver will use the configured data provider to store defender events and it is supported for `MySQL`, `PostgreSQL` and `CockroachDB` data providers. Using the `provider` driver you can share the defender events among multiple SFTPGO instances. For a single instance the `memory` driver will be much faster. Default: `memory`.
-    - `ban_time`, integer. Ban time in minutes.
-    - `ban_time_increment`, integer. Ban time increment, as a percentage, if a banned host tries to connect again.
-    - `threshold`, integer. Threshold value for banning a client.
-    - `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts or client disconnected for inactivity without authentication attempts.
-    - `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
-    - `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
-    - `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
-    - `entries_soft_limit`, integer. Ignored for `provider` driver. Default: 100.
-    - `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit for `memory` driver. If you use the `provider` driver, this setting will limit the number of entries to return when you ask for the entire host list from the defender. Default: 150.
+    - `ban_time`, integer. Ban time in minutes. Default: `30`.
+    - `ban_time_increment`, integer. Ban time increment, as a percentage, if a banned host tries to connect again. Default: `50`.
+    - `threshold`, integer. Threshold value for banning a client. Default: `15`.
+    - `score_invalid`, integer. Score for invalid login attempts, eg. non-existent user accounts. Default: `2`.
+    - `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist. Default: `1`.
+    - `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections. Default: `3`.
+    - `score_no_auth`, defines the score for clients disconnected without any authentication attempt. Default: `0`.
+    - `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes. Default: `30`.
+    - `entries_soft_limit`, integer. Ignored for `provider` driver. Default: `100`.
+    - `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit for `memory` driver. If you use the `provider` driver, this setting will limit the number of entries to return when you ask for the entire host list from the defender. Default: `150`.
     - `safelist_file`, string. Path to a file containing a list of ip addresses and/or networks to never ban.
     - `safelist_file`, string. Path to a file containing a list of ip addresses and/or networks to never ban.
     - `blocklist_file`, string. Path to a file containing a list of ip addresses and/or networks to always ban. The lists can be reloaded on demand sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. An host that is already banned will not be automatically unbanned if you put it inside the safe list, you have to unban it using the REST API.
     - `blocklist_file`, string. Path to a file containing a list of ip addresses and/or networks to always ban. The lists can be reloaded on demand sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. An host that is already banned will not be automatically unbanned if you put it inside the safe list, you have to unban it using the REST API.
     - `safelist`, list of IP addresses and/or IP ranges and/or networks to never ban. Invalid entries will be silently ignored. For large lists prefer `safelist_file`. `safelist` and `safelist_file` will be merged so that you can set both.
     - `safelist`, list of IP addresses and/or IP ranges and/or networks to never ban. Invalid entries will be silently ignored. For large lists prefer `safelist_file`. `safelist` and `safelist_file` will be merged so that you can set both.

+ 2 - 2
go.mod

@@ -157,8 +157,8 @@ require (
 	golang.org/x/tools v0.5.0 // indirect
 	golang.org/x/tools v0.5.0 // indirect
 	golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
 	golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
 	google.golang.org/appengine v1.6.7 // indirect
-	google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2 // indirect
-	google.golang.org/grpc v1.52.0 // indirect
+	google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa // indirect
+	google.golang.org/grpc v1.52.1 // indirect
 	google.golang.org/protobuf v1.28.1 // indirect
 	google.golang.org/protobuf v1.28.1 // indirect
 	gopkg.in/ini.v1 v1.67.0 // indirect
 	gopkg.in/ini.v1 v1.67.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect

+ 4 - 4
go.sum

@@ -2711,8 +2711,8 @@ google.golang.org/genproto v0.0.0-20221109142239-94d6d90a7d66/go.mod h1:rZS5c/ZV
 google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
 google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
 google.golang.org/genproto v0.0.0-20221201164419-0e50fba7f41c/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
 google.golang.org/genproto v0.0.0-20221201164419-0e50fba7f41c/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
 google.golang.org/genproto v0.0.0-20221201204527-e3fa12d562f3/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
 google.golang.org/genproto v0.0.0-20221201204527-e3fa12d562f3/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg=
-google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2 h1:O97sLx/Xmb/KIZHB/2/BzofxBs5QmmR0LcihPtllmbc=
-google.golang.org/genproto v0.0.0-20230124163310-31e0e69b6fc2/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
+google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa h1:qQPhfbPO23fwm/9lQr91L1u62Zo6cm+zI+slZT+uf+o=
+google.golang.org/genproto v0.0.0-20230125152338-dcaf20b6aeaa/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
 google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
 google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
 google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
 google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
@@ -2758,8 +2758,8 @@ google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCD
 google.golang.org/grpc v1.50.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
 google.golang.org/grpc v1.50.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
 google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
 google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
 google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww=
 google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww=
-google.golang.org/grpc v1.52.0 h1:kd48UiU7EHsV4rnLyOJRuP/Il/UHE7gdDAQ+SZI7nZk=
-google.golang.org/grpc v1.52.0/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5vorUY=
+google.golang.org/grpc v1.52.1 h1:2NpOPk5g5Xtb0qebIEs7hNIa++PdtZLo2AQUpc1YnSU=
+google.golang.org/grpc v1.52.1/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5vorUY=
 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
 google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
 google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
 google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
 google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=

+ 6 - 5
internal/common/common.go

@@ -156,8 +156,9 @@ var (
 		ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
 		ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
 	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
 	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
 	// the map key is the protocol, for each protocol we can have multiple rate limiters
 	// the map key is the protocol, for each protocol we can have multiple rate limiters
-	rateLimiters   map[string][]*rateLimiter
-	isShuttingDown atomic.Bool
+	rateLimiters     map[string][]*rateLimiter
+	isShuttingDown   atomic.Bool
+	ftpLoginCommands = []string{"PASS", "USER"}
 )
 )
 
 
 // Initialize sets the common configuration
 // Initialize sets the common configuration
@@ -191,7 +192,7 @@ func Initialize(c Configuration, isShared int) error {
 	}
 	}
 	if c.DefenderConfig.Enabled {
 	if c.DefenderConfig.Enabled {
 		if !util.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
 		if !util.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) {
-			return fmt.Errorf("unsupported defender driver %#v", c.DefenderConfig.Driver)
+			return fmt.Errorf("unsupported defender driver %q", c.DefenderConfig.Driver)
 		}
 		}
 		var defender Defender
 		var defender Defender
 		var err error
 		var err error
@@ -933,9 +934,9 @@ func (conns *ActiveConnections) Remove(connectionID string) {
 		}
 		}
 		conns.removeUserConnection(conn.GetUsername())
 		conns.removeUserConnection(conn.GetUsername())
 		metric.UpdateActiveConnectionsSize(lastIdx)
 		metric.UpdateActiveConnectionsSize(lastIdx)
-		logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
+		logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %q, remote address %q close fs error: %v, num open connections: %d",
 			conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
 			conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
-		if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" {
+		if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !util.Contains(ftpLoginCommands, conn.GetCommand()) {
 			ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
 			ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
 			logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
 			logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
 				dataprovider.ErrNoAuthTryed.Error())
 				dataprovider.ErrNoAuthTryed.Error())

+ 1 - 0
internal/common/common_test.go

@@ -324,6 +324,7 @@ func TestDefenderIntegration(t *testing.T) {
 		Threshold:        0,
 		Threshold:        0,
 		ScoreInvalid:     2,
 		ScoreInvalid:     2,
 		ScoreValid:       1,
 		ScoreValid:       1,
+		ScoreNoAuth:      2,
 		ObservationTime:  15,
 		ObservationTime:  15,
 		EntriesSoftLimit: 100,
 		EntriesSoftLimit: 100,
 		EntriesHardLimit: 150,
 		EntriesHardLimit: 150,

+ 35 - 6
internal/common/defender.go

@@ -78,14 +78,16 @@ type DefenderConfig struct {
 	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
 	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
 	// Threshold value for banning a client
 	// Threshold value for banning a client
 	Threshold int `json:"threshold" mapstructure:"threshold"`
 	Threshold int `json:"threshold" mapstructure:"threshold"`
-	// Score for invalid login attempts, eg. non-existent user accounts or
-	// client disconnected for inactivity without authentication attempts
+	// Score for invalid login attempts, eg. non-existent user accounts
 	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
 	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
 	// Score for valid login attempts, eg. user accounts that exist
 	// Score for valid login attempts, eg. user accounts that exist
 	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
 	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
 	// Score for limit exceeded events, generated from the rate limiters or for max connections
 	// Score for limit exceeded events, generated from the rate limiters or for max connections
 	// per-host exceeded
 	// per-host exceeded
 	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
 	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
+	// ScoreNoAuth defines the score for clients disconnected without authentication
+	// attempts
+	ScoreNoAuth int `json:"score_no_auth" mapstructure:"score_no_auth"`
 	// Defines the time window, in minutes, for tracking client errors.
 	// Defines the time window, in minutes, for tracking client errors.
 	// A host is banned if it has exceeded the defined threshold during
 	// A host is banned if it has exceeded the defined threshold during
 	// the last observation time minutes
 	// the last observation time minutes
@@ -157,8 +159,10 @@ func (d *baseDefender) getScore(event HostEvent) int {
 		score = d.config.ScoreValid
 		score = d.config.ScoreValid
 	case HostEventLimitExceeded:
 	case HostEventLimitExceeded:
 		score = d.config.ScoreLimitExceeded
 		score = d.config.ScoreLimitExceeded
-	case HostEventUserNotFound, HostEventNoLoginTried:
+	case HostEventUserNotFound:
 		score = d.config.ScoreInvalid
 		score = d.config.ScoreInvalid
+	case HostEventNoLoginTried:
+		score = d.config.ScoreNoAuth
 	}
 	}
 	return score
 	return score
 }
 }
@@ -198,19 +202,44 @@ type hostScore struct {
 	Events     []hostEvent
 	Events     []hostEvent
 }
 }
 
 
+func (c *DefenderConfig) checkScores() error {
+	if c.ScoreInvalid < 0 {
+		c.ScoreInvalid = 0
+	}
+	if c.ScoreValid < 0 {
+		c.ScoreValid = 0
+	}
+	if c.ScoreLimitExceeded < 0 {
+		c.ScoreLimitExceeded = 0
+	}
+	if c.ScoreNoAuth < 0 {
+		c.ScoreNoAuth = 0
+	}
+	if c.ScoreInvalid == 0 && c.ScoreValid == 0 && c.ScoreLimitExceeded == 0 && c.ScoreNoAuth == 0 {
+		return fmt.Errorf("invalid defender configuration: all scores are disabled")
+	}
+	return nil
+}
+
 // validate returns an error if the configuration is invalid
 // validate returns an error if the configuration is invalid
 func (c *DefenderConfig) validate() error {
 func (c *DefenderConfig) validate() error {
 	if !c.Enabled {
 	if !c.Enabled {
 		return nil
 		return nil
 	}
 	}
+	if err := c.checkScores(); err != nil {
+		return err
+	}
 	if c.ScoreInvalid >= c.Threshold {
 	if c.ScoreInvalid >= c.Threshold {
-		return fmt.Errorf("score_invalid %v cannot be greater than threshold %v", c.ScoreInvalid, c.Threshold)
+		return fmt.Errorf("score_invalid %d cannot be greater than threshold %d", c.ScoreInvalid, c.Threshold)
 	}
 	}
 	if c.ScoreValid >= c.Threshold {
 	if c.ScoreValid >= c.Threshold {
-		return fmt.Errorf("score_valid %v cannot be greater than threshold %v", c.ScoreValid, c.Threshold)
+		return fmt.Errorf("score_valid %d cannot be greater than threshold %d", c.ScoreValid, c.Threshold)
 	}
 	}
 	if c.ScoreLimitExceeded >= c.Threshold {
 	if c.ScoreLimitExceeded >= c.Threshold {
-		return fmt.Errorf("score_limit_exceeded %v cannot be greater than threshold %v", c.ScoreLimitExceeded, c.Threshold)
+		return fmt.Errorf("score_limit_exceeded %d cannot be greater than threshold %d", c.ScoreLimitExceeded, c.Threshold)
+	}
+	if c.ScoreNoAuth >= c.Threshold {
+		return fmt.Errorf("score_no_auth %d cannot be greater than threshold %d", c.ScoreNoAuth, c.Threshold)
 	}
 	}
 	if c.BanTime <= 0 {
 	if c.BanTime <= 0 {
 		return fmt.Errorf("invalid ban_time %v", c.BanTime)
 		return fmt.Errorf("invalid ban_time %v", c.BanTime)

+ 21 - 1
internal/common/defender_test.go

@@ -62,6 +62,7 @@ func TestBasicDefender(t *testing.T) {
 		Threshold:          5,
 		Threshold:          5,
 		ScoreInvalid:       2,
 		ScoreInvalid:       2,
 		ScoreValid:         1,
 		ScoreValid:         1,
+		ScoreNoAuth:        2,
 		ScoreLimitExceeded: 3,
 		ScoreLimitExceeded: 3,
 		ObservationTime:    15,
 		ObservationTime:    15,
 		EntriesSoftLimit:   1,
 		EntriesSoftLimit:   1,
@@ -140,7 +141,7 @@ func TestBasicDefender(t *testing.T) {
 		assert.True(t, hosts[0].BanTime.IsZero())
 		assert.True(t, hosts[0].BanTime.IsZero())
 		assert.Empty(t, hosts[0].GetBanTime())
 		assert.Empty(t, hosts[0].GetBanTime())
 	}
 	}
-	defender.AddEvent(testIP, HostEventNoLoginTried)
+	defender.AddEvent(testIP, HostEventUserNotFound)
 	defender.AddEvent(testIP, HostEventNoLoginTried)
 	defender.AddEvent(testIP, HostEventNoLoginTried)
 	assert.Equal(t, 0, defender.countHosts())
 	assert.Equal(t, 0, defender.countHosts())
 	assert.Equal(t, 1, defender.countBanned())
 	assert.Equal(t, 1, defender.countBanned())
@@ -511,6 +512,11 @@ func TestDefenderConfig(t *testing.T) {
 	require.Error(t, err)
 	require.Error(t, err)
 
 
 	c.ScoreValid = 1
 	c.ScoreValid = 1
+	c.ScoreNoAuth = 10
+	err = c.validate()
+	require.Error(t, err)
+
+	c.ScoreNoAuth = 2
 	c.BanTime = 0
 	c.BanTime = 0
 	err = c.validate()
 	err = c.validate()
 	require.Error(t, err)
 	require.Error(t, err)
@@ -540,6 +546,20 @@ func TestDefenderConfig(t *testing.T) {
 	c.EntriesHardLimit = 20
 	c.EntriesHardLimit = 20
 	err = c.validate()
 	err = c.validate()
 	require.NoError(t, err)
 	require.NoError(t, err)
+
+	c = DefenderConfig{
+		Enabled:            true,
+		ScoreInvalid:       -1,
+		ScoreLimitExceeded: -1,
+		ScoreNoAuth:        -1,
+		ScoreValid:         -1,
+	}
+	err = c.validate()
+	require.Error(t, err)
+	assert.Equal(t, 0, c.ScoreInvalid)
+	assert.Equal(t, 0, c.ScoreValid)
+	assert.Equal(t, 0, c.ScoreLimitExceeded)
+	assert.Equal(t, 0, c.ScoreNoAuth)
 }
 }
 
 
 func BenchmarkDefenderBannedSearch(b *testing.B) {
 func BenchmarkDefenderBannedSearch(b *testing.B) {

+ 3 - 2
internal/common/defenderdb_test.go

@@ -39,6 +39,7 @@ func TestBasicDbDefender(t *testing.T) {
 		Threshold:          5,
 		Threshold:          5,
 		ScoreInvalid:       2,
 		ScoreInvalid:       2,
 		ScoreValid:         1,
 		ScoreValid:         1,
+		ScoreNoAuth:        2,
 		ScoreLimitExceeded: 3,
 		ScoreLimitExceeded: 3,
 		ObservationTime:    15,
 		ObservationTime:    15,
 		EntriesSoftLimit:   1,
 		EntriesSoftLimit:   1,
@@ -161,9 +162,9 @@ func TestBasicDbDefender(t *testing.T) {
 	testIP2 := "123.45.67.91"
 	testIP2 := "123.45.67.91"
 	testIP3 := "123.45.67.92"
 	testIP3 := "123.45.67.92"
 	for i := 0; i < 3; i++ {
 	for i := 0; i < 3; i++ {
-		defender.AddEvent(testIP, HostEventNoLoginTried)
+		defender.AddEvent(testIP, HostEventUserNotFound)
 		defender.AddEvent(testIP1, HostEventNoLoginTried)
 		defender.AddEvent(testIP1, HostEventNoLoginTried)
-		defender.AddEvent(testIP2, HostEventNoLoginTried)
+		defender.AddEvent(testIP2, HostEventUserNotFound)
 	}
 	}
 	hosts, err = defender.GetHosts()
 	hosts, err = defender.GetHosts()
 	assert.NoError(t, err)
 	assert.NoError(t, err)

+ 1 - 1
internal/common/eventmanager.go

@@ -2297,7 +2297,7 @@ func (j *eventCronJob) getTask(rule *dataprovider.EventRule) (dataprovider.Task,
 	if rule.GuardFromConcurrentExecution() {
 	if rule.GuardFromConcurrentExecution() {
 		task, err := dataprovider.GetTaskByName(rule.Name)
 		task, err := dataprovider.GetTaskByName(rule.Name)
 		if err != nil {
 		if err != nil {
-			if _, ok := err.(*util.RecordNotFoundError); ok {
+			if errors.Is(err, util.ErrNotFound) {
 				eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
 				eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
 				task = dataprovider.Task{
 				task = dataprovider.Task{
 					Name:     rule.Name,
 					Name:     rule.Name,

+ 2 - 1
internal/common/eventmanager_test.go

@@ -17,6 +17,7 @@ package common
 import (
 import (
 	"bytes"
 	"bytes"
 	"crypto/rand"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"mime/multipart"
 	"mime/multipart"
@@ -383,7 +384,7 @@ func TestEventManager(t *testing.T) {
 
 
 	assert.Eventually(t, func() bool {
 	assert.Eventually(t, func() bool {
 		_, err = dataprovider.EventRuleExists(rule.Name)
 		_, err = dataprovider.EventRuleExists(rule.Name)
-		_, ok := err.(*util.RecordNotFoundError)
+		ok := errors.Is(err, util.ErrNotFound)
 		return ok
 		return ok
 	}, 2*time.Second, 100*time.Millisecond)
 	}, 2*time.Second, 100*time.Millisecond)
 
 

+ 2 - 0
internal/config/config.go

@@ -221,6 +221,7 @@ func Init() {
 				ScoreInvalid:       2,
 				ScoreInvalid:       2,
 				ScoreValid:         1,
 				ScoreValid:         1,
 				ScoreLimitExceeded: 3,
 				ScoreLimitExceeded: 3,
+				ScoreNoAuth:        2,
 				ObservationTime:    30,
 				ObservationTime:    30,
 				EntriesSoftLimit:   100,
 				EntriesSoftLimit:   100,
 				EntriesHardLimit:   150,
 				EntriesHardLimit:   150,
@@ -1968,6 +1969,7 @@ func setViperDefaults() {
 	viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
 	viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid)
 	viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
 	viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid)
 	viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
 	viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded)
+	viper.SetDefault("common.defender.score_no_auth", globalConf.Common.DefenderConfig.ScoreNoAuth)
 	viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
 	viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime)
 	viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
 	viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit)
 	viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)
 	viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit)

+ 1 - 1
internal/dataprovider/dataprovider.go

@@ -4049,7 +4049,7 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
 func getUserForHook(username string, oidcTokenFields *map[string]any) (User, User, error) {
 func getUserForHook(username string, oidcTokenFields *map[string]any) (User, User, error) {
 	u, err := provider.userExists(username, "")
 	u, err := provider.userExists(username, "")
 	if err != nil {
 	if err != nil {
-		if _, ok := err.(*util.RecordNotFoundError); !ok {
+		if !errors.Is(err, util.ErrNotFound) {
 			return u, u, err
 			return u, u, err
 		}
 		}
 		u = User{
 		u = User{

+ 1 - 1
internal/dataprovider/memory.go

@@ -1379,7 +1379,7 @@ func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFo
 		p.updateFoldersMappingInternal(folder)
 		p.updateFoldersMappingInternal(folder)
 		return folder, nil
 		return folder, nil
 	}
 	}
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		folder = baseFolder.GetACopy()
 		folder = baseFolder.GetACopy()
 		folder.ID = p.getNextFolderID()
 		folder.ID = p.getNextFolderID()
 		folder.UsedQuotaSize = usedQuotaSize
 		folder.UsedQuotaSize = usedQuotaSize

+ 1 - 1
internal/dataprovider/user.go

@@ -526,7 +526,7 @@ func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
 		}
 		}
 		return forbiddens, nil
 		return forbiddens, nil
 	}
 	}
-	if _, ok := err.(*util.RecordNotFoundError); !ok {
+	if !errors.Is(err, util.ErrNotFound) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 

+ 2 - 2
internal/ftpd/server.go

@@ -231,7 +231,7 @@ func (s *Server) PreAuthUser(cc ftpserver.ClientContext, username string) error
 			}
 			}
 			return nil
 			return nil
 		}
 		}
-		if _, ok := err.(*util.RecordNotFoundError); !ok {
+		if !errors.Is(err, util.ErrNotFound) {
 			logger.Error(logSender, fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()),
 			logger.Error(logSender, fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()),
 				"unable to get user on pre auth: %v", err)
 				"unable to get user on pre auth: %v", err)
 			return common.ErrInternalFailure
 			return common.ErrInternalFailure
@@ -426,7 +426,7 @@ func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err err
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod,
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod,
 			common.ProtocolFTP, err.Error())
 			common.ProtocolFTP, err.Error())
 		event := common.HostEventLoginFailed
 		event := common.HostEventLoginFailed
-		if _, ok := err.(*util.RecordNotFoundError); ok {
+		if errors.Is(err, util.ErrNotFound) {
 			event = common.HostEventUserNotFound
 			event = common.HostEventUserNotFound
 		}
 		}
 		common.AddDefenderEvent(ip, event)
 		common.AddDefenderEvent(ip, event)

+ 1 - 1
internal/httpd/api_user.go

@@ -248,7 +248,7 @@ func resetUserPassword(w http.ResponseWriter, r *http.Request) {
 	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
 	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
 }
 }
 
 
-func disconnectUser(username string, admin, role string) {
+func disconnectUser(username, admin, role string) {
 	for _, stat := range common.Connections.GetStats("") {
 	for _, stat := range common.Connections.GetStats("") {
 		if stat.Username == username {
 		if stat.Username == username {
 			common.Connections.Close(stat.ConnectionID, "")
 			common.Connections.Close(stat.ConnectionID, "")

+ 3 - 3
internal/httpd/api_utils.go

@@ -72,7 +72,7 @@ type userProfile struct {
 
 
 func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
 func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
 	var errorString string
 	var errorString string
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		errorString = http.StatusText(http.StatusNotFound)
 		errorString = http.StatusText(http.StatusNotFound)
 	} else if err != nil {
 	} else if err != nil {
 		errorString = err.Error()
 		errorString = err.Error()
@@ -600,7 +600,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error())
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error())
 		event := common.HostEventLoginFailed
 		event := common.HostEventLoginFailed
-		if _, ok := err.(*util.RecordNotFoundError); ok {
+		if errors.Is(err, util.ErrNotFound) {
 			event = common.HostEventUserNotFound
 			event = common.HostEventUserNotFound
 		}
 		}
 		common.AddDefenderEvent(ip, event)
 		common.AddDefenderEvent(ip, event)
@@ -657,7 +657,7 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error
 		}
 		}
 	}
 	}
 	if err != nil {
 	if err != nil {
-		if _, ok := err.(*util.RecordNotFoundError); ok {
+		if errors.Is(err, util.ErrNotFound) {
 			logger.Debug(logSender, middleware.GetReqID(r.Context()), "username %#v does not exists, reset password request silently ignored, is admin? %v",
 			logger.Debug(logSender, middleware.GetReqID(r.Context()), "username %#v does not exists, reset password request silently ignored, is admin? %v",
 				username, isAdmin)
 				username, isAdmin)
 			return nil
 			return nil

+ 8 - 3
internal/httpd/httpd_test.go

@@ -6959,6 +6959,7 @@ func TestDefenderAPI(t *testing.T) {
 		cfg.DefenderConfig.Driver = driver
 		cfg.DefenderConfig.Driver = driver
 		cfg.DefenderConfig.Threshold = 3
 		cfg.DefenderConfig.Threshold = 3
 		cfg.DefenderConfig.ScoreLimitExceeded = 2
 		cfg.DefenderConfig.ScoreLimitExceeded = 2
+		cfg.DefenderConfig.ScoreNoAuth = 0
 
 
 		err := common.Initialize(cfg, 0)
 		err := common.Initialize(cfg, 0)
 		assert.NoError(t, err)
 		assert.NoError(t, err)
@@ -6975,6 +6976,10 @@ func TestDefenderAPI(t *testing.T) {
 		common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
 		common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		assert.NoError(t, err)
 		assert.NoError(t, err)
+		assert.Len(t, hosts, 0)
+		common.AddDefenderEvent(ip, common.HostEventUserNotFound)
+		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
+		assert.NoError(t, err)
 		if assert.Len(t, hosts, 1) {
 		if assert.Len(t, hosts, 1) {
 			host := hosts[0]
 			host := hosts[0]
 			assert.Empty(t, host.GetBanTime())
 			assert.Empty(t, host.GetBanTime())
@@ -6986,7 +6991,7 @@ func TestDefenderAPI(t *testing.T) {
 		assert.Empty(t, host.GetBanTime())
 		assert.Empty(t, host.GetBanTime())
 		assert.Equal(t, 2, host.Score)
 		assert.Equal(t, 2, host.Score)
 
 
-		common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
+		common.AddDefenderEvent(ip, common.HostEventUserNotFound)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 		if assert.Len(t, hosts, 1) {
 		if assert.Len(t, hosts, 1) {
@@ -7006,8 +7011,8 @@ func TestDefenderAPI(t *testing.T) {
 		_, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
 		_, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 
 
-		common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
-		common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
+		common.AddDefenderEvent(ip, common.HostEventUserNotFound)
+		common.AddDefenderEvent(ip, common.HostEventUserNotFound)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
 		assert.NoError(t, err)
 		assert.NoError(t, err)
 		assert.Len(t, hosts, 1)
 		assert.Len(t, hosts, 1)

+ 2 - 4
internal/httpd/internal_test.go

@@ -2570,8 +2570,7 @@ func TestBrowsableSharePaths(t *testing.T) {
 	}
 	}
 	_, err := getUserForShare(share)
 	_, err := getUserForShare(share)
 	if assert.Error(t, err) {
 	if assert.Error(t, err) {
-		_, ok := err.(*util.RecordNotFoundError)
-		assert.True(t, ok)
+		assert.ErrorIs(t, err, util.ErrNotFound)
 	}
 	}
 	req, err := http.NewRequest(http.MethodGet, "/share", nil)
 	req, err := http.NewRequest(http.MethodGet, "/share", nil)
 	require.NoError(t, err)
 	require.NoError(t, err)
@@ -2876,8 +2875,7 @@ func TestDbResetCodeManager(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	err = mgr.Delete(resetCode.Code)
 	err = mgr.Delete(resetCode.Code)
 	if assert.Error(t, err) {
 	if assert.Error(t, err) {
-		_, ok := err.(*util.RecordNotFoundError)
-		assert.True(t, ok)
+		assert.ErrorIs(t, err, util.ErrNotFound)
 	}
 	}
 	_, err = mgr.Get(resetCode.Code)
 	_, err = mgr.Get(resetCode.Code)
 	assert.ErrorIs(t, err, sql.ErrNoRows)
 	assert.ErrorIs(t, err, sql.ErrNoRows)

+ 3 - 6
internal/httpd/oidc_test.go

@@ -851,8 +851,7 @@ func TestOIDCToken(t *testing.T) {
 	token.Role = ""
 	token.Role = ""
 	err = token.getUser(req)
 	err = token.getUser(req)
 	if assert.Error(t, err) {
 	if assert.Error(t, err) {
-		_, ok := err.(*util.RecordNotFoundError)
-		assert.True(t, ok)
+		assert.ErrorIs(t, err, util.ErrNotFound)
 	}
 	}
 
 
 	user := dataprovider.User{
 	user := dataprovider.User{
@@ -1165,8 +1164,7 @@ func TestOIDCPreLoginHook(t *testing.T) {
 	server.initializeRouter()
 	server.initializeRouter()
 
 
 	_, err = dataprovider.UserExists(username, "")
 	_, err = dataprovider.UserExists(username, "")
-	_, ok = err.(*util.RecordNotFoundError)
-	assert.True(t, ok)
+	assert.ErrorIs(t, err, util.ErrNotFound)
 	// now login with OIDC
 	// now login with OIDC
 	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
 	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
 	oidcMgr.addPendingAuth(authReq)
 	oidcMgr.addPendingAuth(authReq)
@@ -1226,8 +1224,7 @@ func TestOIDCPreLoginHook(t *testing.T) {
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.Equal(t, http.StatusFound, rr.Code)
 	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
 	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
 	_, err = dataprovider.UserExists(username, "")
 	_, err = dataprovider.UserExists(username, "")
-	_, ok = err.(*util.RecordNotFoundError)
-	assert.True(t, ok)
+	assert.ErrorIs(t, err, util.ErrNotFound)
 	if assert.Len(t, oidcMgr.tokens, 1) {
 	if assert.Len(t, oidcMgr.tokens, 1) {
 		for k := range oidcMgr.tokens {
 		for k := range oidcMgr.tokens {
 			oidcMgr.removeToken(k)
 			oidcMgr.removeToken(k)

+ 16 - 16
internal/httpd/webadmin.go

@@ -2591,7 +2591,7 @@ func (s *httpdServer) handleWebUpdateAdminGet(w http.ResponseWriter, r *http.Req
 	admin, err := dataprovider.AdminExists(username)
 	admin, err := dataprovider.AdminExists(username)
 	if err == nil {
 	if err == nil {
 		s.renderAddUpdateAdminPage(w, r, &admin, "", false)
 		s.renderAddUpdateAdminPage(w, r, &admin, "", false)
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -2631,7 +2631,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
 
 
 	username := getURLParam(r, "username")
 	username := getURLParam(r, "username")
 	admin, err := dataprovider.AdminExists(username)
 	admin, err := dataprovider.AdminExists(username)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -2737,7 +2737,7 @@ func (s *httpdServer) handleWebTemplateFolderGet(w http.ResponseWriter, r *http.
 		if err == nil {
 		if err == nil {
 			folder.FsConfig.SetEmptySecrets()
 			folder.FsConfig.SetEmptySecrets()
 			s.renderFolderPage(w, r, folder, folderPageModeTemplate, "")
 			s.renderFolderPage(w, r, folder, folderPageModeTemplate, "")
-		} else if _, ok := err.(*util.RecordNotFoundError); ok {
+		} else if errors.Is(err, util.ErrNotFound) {
 			s.renderNotFoundPage(w, r, err)
 			s.renderNotFoundPage(w, r, err)
 		} else {
 		} else {
 			s.renderInternalServerErrorPage(w, r, err)
 			s.renderInternalServerErrorPage(w, r, err)
@@ -2831,7 +2831,7 @@ func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Re
 				user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
 				user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
 			}
 			}
 			s.renderUserPage(w, r, &user, userPageModeTemplate, "", &admin)
 			s.renderUserPage(w, r, &user, userPageModeTemplate, "", &admin)
-		} else if _, ok := err.(*util.RecordNotFoundError); ok {
+		} else if errors.Is(err, util.ErrNotFound) {
 			s.renderNotFoundPage(w, r, err)
 			s.renderNotFoundPage(w, r, err)
 		} else {
 		} else {
 			s.renderInternalServerErrorPage(w, r, err)
 			s.renderInternalServerErrorPage(w, r, err)
@@ -2939,7 +2939,7 @@ func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Requ
 	user, err := dataprovider.UserExists(username, claims.Role)
 	user, err := dataprovider.UserExists(username, claims.Role)
 	if err == nil {
 	if err == nil {
 		s.renderUserPage(w, r, &user, userPageModeUpdate, "", nil)
 		s.renderUserPage(w, r, &user, userPageModeUpdate, "", nil)
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -2992,7 +2992,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
 	}
 	}
 	username := getURLParam(r, "username")
 	username := getURLParam(r, "username")
 	user, err := dataprovider.UserExists(username, claims.Role)
 	user, err := dataprovider.UserExists(username, claims.Role)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -3118,7 +3118,7 @@ func (s *httpdServer) handleWebUpdateFolderGet(w http.ResponseWriter, r *http.Re
 	folder, err := dataprovider.GetFolderByName(name)
 	folder, err := dataprovider.GetFolderByName(name)
 	if err == nil {
 	if err == nil {
 		s.renderFolderPage(w, r, folder, folderPageModeUpdate, "")
 		s.renderFolderPage(w, r, folder, folderPageModeUpdate, "")
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -3134,7 +3134,7 @@ func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.R
 	}
 	}
 	name := getURLParam(r, "name")
 	name := getURLParam(r, "name")
 	folder, err := dataprovider.GetFolderByName(name)
 	folder, err := dataprovider.GetFolderByName(name)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -3294,7 +3294,7 @@ func (s *httpdServer) handleWebUpdateGroupGet(w http.ResponseWriter, r *http.Req
 	group, err := dataprovider.GroupExists(name)
 	group, err := dataprovider.GroupExists(name)
 	if err == nil {
 	if err == nil {
 		s.renderGroupPage(w, r, group, genericPageModeUpdate, "")
 		s.renderGroupPage(w, r, group, genericPageModeUpdate, "")
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -3310,7 +3310,7 @@ func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Re
 	}
 	}
 	name := getURLParam(r, "name")
 	name := getURLParam(r, "name")
 	group, err := dataprovider.GroupExists(name)
 	group, err := dataprovider.GroupExists(name)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -3423,7 +3423,7 @@ func (s *httpdServer) handleWebUpdateEventActionGet(w http.ResponseWriter, r *ht
 	action, err := dataprovider.EventActionExists(name)
 	action, err := dataprovider.EventActionExists(name)
 	if err == nil {
 	if err == nil {
 		s.renderEventActionPage(w, r, action, genericPageModeUpdate, "")
 		s.renderEventActionPage(w, r, action, genericPageModeUpdate, "")
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -3439,7 +3439,7 @@ func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *h
 	}
 	}
 	name := getURLParam(r, "name")
 	name := getURLParam(r, "name")
 	action, err := dataprovider.EventActionExists(name)
 	action, err := dataprovider.EventActionExists(name)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -3541,7 +3541,7 @@ func (s *httpdServer) handleWebUpdateEventRuleGet(w http.ResponseWriter, r *http
 	rule, err := dataprovider.EventRuleExists(name)
 	rule, err := dataprovider.EventRuleExists(name)
 	if err == nil {
 	if err == nil {
 		s.renderEventRulePage(w, r, rule, genericPageModeUpdate, "")
 		s.renderEventRulePage(w, r, rule, genericPageModeUpdate, "")
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -3557,7 +3557,7 @@ func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *htt
 	}
 	}
 	name := getURLParam(r, "name")
 	name := getURLParam(r, "name")
 	rule, err := dataprovider.EventRuleExists(name)
 	rule, err := dataprovider.EventRuleExists(name)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {
@@ -3648,7 +3648,7 @@ func (s *httpdServer) handleWebUpdateRoleGet(w http.ResponseWriter, r *http.Requ
 	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
 	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
 	if err == nil {
 	if err == nil {
 		s.renderRolePage(w, r, role, genericPageModeUpdate, "")
 		s.renderRolePage(w, r, role, genericPageModeUpdate, "")
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderInternalServerErrorPage(w, r, err)
 		s.renderInternalServerErrorPage(w, r, err)
@@ -3663,7 +3663,7 @@ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Req
 		return
 		return
 	}
 	}
 	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
 	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderNotFoundPage(w, r, err)
 		s.renderNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {

+ 2 - 2
internal/httpd/webclient.go

@@ -1071,7 +1071,7 @@ func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.
 	if err == nil {
 	if err == nil {
 		share.HideConfidentialData()
 		share.HideConfidentialData()
 		s.renderAddUpdateSharePage(w, r, &share, "", false)
 		s.renderAddUpdateSharePage(w, r, &share, "", false)
-	} else if _, ok := err.(*util.RecordNotFoundError); ok {
+	} else if errors.Is(err, util.ErrNotFound) {
 		s.renderClientNotFoundPage(w, r, err)
 		s.renderClientNotFoundPage(w, r, err)
 	} else {
 	} else {
 		s.renderClientInternalServerErrorPage(w, r, err)
 		s.renderClientInternalServerErrorPage(w, r, err)
@@ -1122,7 +1122,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
 	}
 	}
 	shareID := getURLParam(r, "id")
 	shareID := getURLParam(r, "id")
 	share, err := dataprovider.ShareExists(shareID, claims.Username)
 	share, err := dataprovider.ShareExists(shareID, claims.Username)
-	if _, ok := err.(*util.RecordNotFoundError); ok {
+	if errors.Is(err, util.ErrNotFound) {
 		s.renderClientNotFoundPage(w, r, err)
 		s.renderClientNotFoundPage(w, r, err)
 		return
 		return
 	} else if err != nil {
 	} else if err != nil {

+ 19 - 0
internal/sftpd/internal_test.go

@@ -19,6 +19,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"io/fs"
 	"net"
 	"net"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
@@ -2226,3 +2227,21 @@ func TestCanReadSymlink(t *testing.T) {
 	err = connection.canReadLink("/denied/file.txt")
 	err = connection.canReadLink("/denied/file.txt")
 	assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
 	assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
 }
 }
+
+func TestAuthenticationErrors(t *testing.T) {
+	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")))
+	assert.ErrorIs(t, err, sftpAuthError)
+	assert.ErrorIs(t, err, util.ErrNotFound)
+	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission))
+	assert.ErrorIs(t, err, sftpAuthError)
+	assert.NotErrorIs(t, err, util.ErrNotFound)
+	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert))
+	assert.ErrorIs(t, err, sftpAuthError)
+	assert.NotErrorIs(t, err, util.ErrNotFound)
+	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"))
+	assert.ErrorIs(t, err, sftpAuthError)
+	assert.NotErrorIs(t, err, util.ErrNotFound)
+	err = newAuthenticationError(nil)
+	assert.ErrorIs(t, err, sftpAuthError)
+	assert.NotErrorIs(t, err, util.ErrNotFound)
+}

+ 31 - 18
internal/sftpd/server.go

@@ -92,6 +92,8 @@ var (
 	revokedCertManager = revokedCertificates{
 	revokedCertManager = revokedCertificates{
 		certs: map[string]bool{},
 		certs: map[string]bool{},
 	}
 	}
+
+	sftpAuthError = newAuthenticationError(nil)
 )
 )
 
 
 // Binding defines the configuration for a network listener
 // Binding defines the configuration for a network listener
@@ -208,11 +210,26 @@ type Configuration struct {
 }
 }
 
 
 type authenticationError struct {
 type authenticationError struct {
-	err string
+	err error
 }
 }
 
 
 func (e *authenticationError) Error() string {
 func (e *authenticationError) Error() string {
-	return fmt.Sprintf("Authentication error: %s", e.err)
+	return fmt.Sprintf("Authentication error: %v", e.err)
+}
+
+// Is reports if target matches
+func (e *authenticationError) Is(target error) bool {
+	_, ok := target.(*authenticationError)
+	return ok
+}
+
+// Unwrap returns the wrapped error
+func (e *authenticationError) Unwrap() error {
+	return e.err
+}
+
+func newAuthenticationError(err error) *authenticationError {
+	return &authenticationError{err: err}
 }
 }
 
 
 // ShouldBind returns true if there is at least a valid binding
 // ShouldBind returns true if there is at least a valid binding
@@ -236,7 +253,7 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 				return sp, err
 				return sp, err
 			}
 			}
 			if err != nil {
 			if err != nil {
-				return nil, &authenticationError{err: fmt.Sprintf("could not validate public key credentials: %v", err)}
+				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err))
 			}
 			}
 
 
 			return sp, nil
 			return sp, nil
@@ -256,7 +273,7 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 		serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
 		serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
 			sp, err := c.validatePasswordCredentials(conn, pass)
 			sp, err := c.validatePasswordCredentials(conn, pass)
 			if err != nil {
 			if err != nil {
-				return nil, &authenticationError{err: fmt.Sprintf("could not validate password credentials: %v", err)}
+				return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err))
 			}
 			}
 
 
 			return sp, nil
 			return sp, nil
@@ -453,9 +470,9 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
 	if c.KeyboardInteractiveHook != "" {
 	if c.KeyboardInteractiveHook != "" {
 		if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") {
 		if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") {
 			if !filepath.IsAbs(c.KeyboardInteractiveHook) {
 			if !filepath.IsAbs(c.KeyboardInteractiveHook) {
-				logger.WarnToConsole("invalid keyboard interactive authentication program: %#v must be an absolute path",
+				logger.WarnToConsole("invalid keyboard interactive authentication program: %q must be an absolute path",
 					c.KeyboardInteractiveHook)
 					c.KeyboardInteractiveHook)
-				logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %#v must be an absolute path",
+				logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %q must be an absolute path",
 					c.KeyboardInteractiveHook)
 					c.KeyboardInteractiveHook)
 				return
 				return
 			}
 			}
@@ -470,7 +487,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
 	serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
 	serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
 		sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
 		sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
 		if err != nil {
 		if err != nil {
-			return nil, &authenticationError{err: fmt.Sprintf("could not validate keyboard interactive credentials: %v", err)}
+			return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err))
 		}
 		}
 
 
 		return sp, nil
 		return sp, nil
@@ -666,20 +683,16 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
 
 
 func checkAuthError(ip string, err error) {
 func checkAuthError(ip string, err error) {
 	if authErrors, ok := err.(*ssh.ServerAuthError); ok {
 	if authErrors, ok := err.(*ssh.ServerAuthError); ok {
-		// check public key auth errors here
+		event := common.HostEventLoginFailed
 		for _, err := range authErrors.Errors {
 		for _, err := range authErrors.Errors {
-			if err != nil {
-				// these checks should be improved, we should check for error type and not error strings
-				if strings.Contains(err.Error(), "public key credentials") {
-					event := common.HostEventLoginFailed
-					if strings.Contains(err.Error(), "not found") {
-						event = common.HostEventUserNotFound
-					}
-					common.AddDefenderEvent(ip, event)
-					break
+			if errors.Is(err, sftpAuthError) {
+				if errors.Is(err, util.ErrNotFound) {
+					event = common.HostEventUserNotFound
 				}
 				}
+				break
 			}
 			}
 		}
 		}
+		common.AddDefenderEvent(ip, event)
 	} else {
 	} else {
 		logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
 		logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
 		metric.AddNoAuthTryed()
 		metric.AddNoAuthTryed()
@@ -1131,7 +1144,7 @@ func updateLoginMetrics(user *dataprovider.User, ip, method string, err error) {
 			// record failed login key auth only once for session if the
 			// record failed login key auth only once for session if the
 			// authentication fails in checkAuthError
 			// authentication fails in checkAuthError
 			event := common.HostEventLoginFailed
 			event := common.HostEventLoginFailed
-			if _, ok := err.(*util.RecordNotFoundError); ok {
+			if errors.Is(err, util.ErrNotFound) {
 				event = common.HostEventUserNotFound
 				event = common.HostEventUserNotFound
 			}
 			}
 			common.AddDefenderEvent(ip, event)
 			common.AddDefenderEvent(ip, event)

+ 1 - 2
internal/webdavd/server.go

@@ -188,7 +188,6 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 	user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr)
 	user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr)
 	if err != nil {
 	if err != nil {
-		updateLoginMetrics(&user, ipAddr, loginMethod, err)
 		if !s.binding.DisableWWWAuthHeader {
 		if !s.binding.DisableWWWAuthHeader {
 			w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
 			w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
 		}
 		}
@@ -411,7 +410,7 @@ func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err err
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
 	if err != nil && err != common.ErrInternalFailure && err != common.ErrNoCredentials {
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, err.Error())
 		logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, err.Error())
 		event := common.HostEventLoginFailed
 		event := common.HostEventLoginFailed
-		if _, ok := err.(*util.RecordNotFoundError); ok {
+		if errors.Is(err, util.ErrNotFound) {
 			event = common.HostEventUserNotFound
 			event = common.HostEventUserNotFound
 		}
 		}
 		common.AddDefenderEvent(ip, event)
 		common.AddDefenderEvent(ip, event)

+ 1 - 0
sftpgo.json

@@ -29,6 +29,7 @@
       "score_invalid": 2,
       "score_invalid": 2,
       "score_valid": 1,
       "score_valid": 1,
       "score_limit_exceeded": 3,
       "score_limit_exceeded": 3,
+      "score_no_auth": 2,
       "observation_time": 30,
       "observation_time": 30,
       "entries_soft_limit": 100,
       "entries_soft_limit": 100,
       "entries_hard_limit": 150,
       "entries_hard_limit": 150,