sftpd: fix duplicate defender error introduced in the previous commit

improve the defender test cases by verifying the expected score

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2023-01-25 21:57:27 +01:00
parent 87820d980b
commit 9c9c9fa3a5
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
5 changed files with 91 additions and 27 deletions

View file

@ -1724,8 +1724,10 @@ func TestDefender(t *testing.T) {
cfg := config.GetCommonConfig()
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.Threshold = 4
cfg.DefenderConfig.ScoreLimitExceeded = 2
cfg.DefenderConfig.ScoreNoAuth = 1
cfg.DefenderConfig.ScoreValid = 1
err := common.Initialize(cfg, 0)
assert.NoError(t, err)
@ -1739,9 +1741,31 @@ func TestDefender(t *testing.T) {
err = client.Quit()
assert.NoError(t, err)
}
// just dial without login
ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
assert.NoError(t, err)
err = client.Quit()
assert.NoError(t, err)
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
assert.NoError(t, err)
if assert.Len(t, hosts, 1) {
host := hosts[0]
assert.Empty(t, host.GetBanTime())
assert.Equal(t, 1, host.Score)
}
user.Password = "wrong_pwd"
_, err = getFTPClient(user, false, nil)
assert.Error(t, err)
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
assert.NoError(t, err)
if assert.Len(t, hosts, 1) {
host := hosts[0]
assert.Empty(t, host.GetBanTime())
assert.Equal(t, 2, host.Score)
}
for i := 0; i < 3; i++ {
user.Password = "wrong_pwd"
for i := 0; i < 2; i++ {
_, err = getFTPClient(user, false, nil)
assert.Error(t, err)
}

View file

@ -2301,19 +2301,24 @@ func TestCanReadSymlink(t *testing.T) {
}
func TestAuthenticationErrors(t *testing.T) {
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")))
loginMethod := dataprovider.SSHLoginMethodPassword
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod)
assert.ErrorIs(t, err, sftpAuthError)
assert.ErrorIs(t, err, util.ErrNotFound)
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission))
var sftpAuthErr *authenticationError
if assert.ErrorAs(t, err, &sftpAuthErr) {
assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
}
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod)
assert.ErrorIs(t, err, sftpAuthError)
assert.NotErrorIs(t, err, util.ErrNotFound)
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert))
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod)
assert.ErrorIs(t, err, sftpAuthError)
assert.NotErrorIs(t, err, util.ErrNotFound)
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"))
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod)
assert.ErrorIs(t, err, sftpAuthError)
assert.NotErrorIs(t, err, util.ErrNotFound)
err = newAuthenticationError(nil)
err = newAuthenticationError(nil, loginMethod)
assert.ErrorIs(t, err, sftpAuthError)
assert.NotErrorIs(t, err, util.ErrNotFound)
}

View file

@ -93,7 +93,7 @@ var (
certs: map[string]bool{},
}
sftpAuthError = newAuthenticationError(nil)
sftpAuthError = newAuthenticationError(nil, "")
)
// Binding defines the configuration for a network listener
@ -210,7 +210,8 @@ type Configuration struct {
}
type authenticationError struct {
err error
err error
loginMethod string
}
func (e *authenticationError) Error() string {
@ -228,8 +229,12 @@ func (e *authenticationError) Unwrap() error {
return e.err
}
func newAuthenticationError(err error) *authenticationError {
return &authenticationError{err: err}
func (e *authenticationError) getLoginMethod() string {
return e.loginMethod
}
func newAuthenticationError(err error, loginMethod string) *authenticationError {
return &authenticationError{err: err, loginMethod: loginMethod}
}
// ShouldBind returns true if there is at least a valid binding
@ -253,7 +258,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
return sp, err
}
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err))
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
dataprovider.SSHLoginMethodPublicKey)
}
return sp, nil
@ -273,7 +279,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
sp, err := c.validatePasswordCredentials(conn, pass)
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err))
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
dataprovider.SSHLoginMethodPassword)
}
return sp, nil
@ -487,7 +494,8 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err))
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
dataprovider.SSHLoginMethodKeyboardInteractive)
}
return sp, nil
@ -561,7 +569,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
}
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType,
"User %q logged in with %q, from ip %q, client version %q", user.Username, loginType,
ipAddr, string(sconn.ClientVersion()))
dataprovider.UpdateLastLogin(&user)
@ -683,16 +691,20 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
func checkAuthError(ip string, err error) {
if authErrors, ok := err.(*ssh.ServerAuthError); ok {
event := common.HostEventLoginFailed
// check public key auth errors here
for _, err := range authErrors.Errors {
if errors.Is(err, sftpAuthError) {
if errors.Is(err, util.ErrNotFound) {
event = common.HostEventUserNotFound
var sftpAuthErr *authenticationError
if errors.As(err, &sftpAuthErr) {
if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey {
event := common.HostEventLoginFailed
if errors.Is(err, util.ErrNotFound) {
event = common.HostEventUserNotFound
}
common.AddDefenderEvent(ip, event)
return
}
break
}
}
common.AddDefenderEvent(ip, event)
} else {
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
metric.AddNoAuthTryed()
@ -1078,7 +1090,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
}
if user.IsPartialAuth(method) {
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
return certPerm, ssh.ErrPartialSuccess
}
sshPerm, err = loginUser(&user, method, keyID, conn)

View file

@ -966,6 +966,7 @@ func TestDefender(t *testing.T) {
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreLimitExceeded = 2
cfg.DefenderConfig.ScoreValid = 1
err := common.Initialize(cfg, 0)
assert.NoError(t, err)
@ -977,12 +978,23 @@ func TestDefender(t *testing.T) {
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
}
for i := 0; i < 3; i++ {
user.Password = "wrong_pwd"
user.Password = "wrong_pwd"
_, _, err = getSftpClient(user, usePubKey)
assert.Error(t, err)
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
assert.NoError(t, err)
if assert.Len(t, hosts, 1) {
host := hosts[0]
assert.Empty(t, host.GetBanTime())
assert.Equal(t, 1, host.Score)
}
for i := 0; i < 2; i++ {
_, _, err = getSftpClient(user, usePubKey)
assert.Error(t, err)
}

View file

@ -995,6 +995,7 @@ func TestDefender(t *testing.T) {
cfg.DefenderConfig.Enabled = true
cfg.DefenderConfig.Threshold = 3
cfg.DefenderConfig.ScoreLimitExceeded = 2
cfg.DefenderConfig.ScoreValid = 1
err := common.Initialize(cfg, 0)
assert.NoError(t, err)
@ -1004,8 +1005,18 @@ func TestDefender(t *testing.T) {
client := getWebDavClient(user, true, nil)
assert.NoError(t, checkBasicFunc(client))
for i := 0; i < 3; i++ {
user.Password = "wrong_pwd"
user.Password = "wrong_pwd"
client = getWebDavClient(user, false, nil)
assert.Error(t, checkBasicFunc(client))
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
assert.NoError(t, err)
if assert.Len(t, hosts, 1) {
host := hosts[0]
assert.Empty(t, host.GetBanTime())
assert.Equal(t, 1, host.Score)
}
for i := 0; i < 2; i++ {
client = getWebDavClient(user, false, nil)
assert.Error(t, checkBasicFunc(client))
}