mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-21 23:20:24 +00:00
sftpd: fix duplicate defender error introduced in the previous commit
improve the defender test cases by verifying the expected score Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
c0fe08b597
commit
27c4ffd663
5 changed files with 91 additions and 27 deletions
|
@ -1749,8 +1749,10 @@ func TestDefender(t *testing.T) {
|
|||
|
||||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.Threshold = 4
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
cfg.DefenderConfig.ScoreNoAuth = 1
|
||||
cfg.DefenderConfig.ScoreValid = 1
|
||||
|
||||
err := common.Initialize(cfg, 0)
|
||||
assert.NoError(t, err)
|
||||
|
@ -1764,9 +1766,31 @@ func TestDefender(t *testing.T) {
|
|||
err = client.Quit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
// just dial without login
|
||||
ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
|
||||
client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
|
||||
assert.NoError(t, err)
|
||||
err = client.Quit()
|
||||
assert.NoError(t, err)
|
||||
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 1, host.Score)
|
||||
}
|
||||
user.Password = "wrong_pwd"
|
||||
_, err = getFTPClient(user, false, nil)
|
||||
assert.Error(t, err)
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 2, host.Score)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
user.Password = "wrong_pwd"
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err = getFTPClient(user, false, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -2229,19 +2229,24 @@ func TestCanReadSymlink(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthenticationErrors(t *testing.T) {
|
||||
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")))
|
||||
loginMethod := dataprovider.SSHLoginMethodPassword
|
||||
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod)
|
||||
assert.ErrorIs(t, err, sftpAuthError)
|
||||
assert.ErrorIs(t, err, util.ErrNotFound)
|
||||
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission))
|
||||
var sftpAuthErr *authenticationError
|
||||
if assert.ErrorAs(t, err, &sftpAuthErr) {
|
||||
assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
|
||||
}
|
||||
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod)
|
||||
assert.ErrorIs(t, err, sftpAuthError)
|
||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert))
|
||||
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod)
|
||||
assert.ErrorIs(t, err, sftpAuthError)
|
||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"))
|
||||
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod)
|
||||
assert.ErrorIs(t, err, sftpAuthError)
|
||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||
err = newAuthenticationError(nil)
|
||||
err = newAuthenticationError(nil, loginMethod)
|
||||
assert.ErrorIs(t, err, sftpAuthError)
|
||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ var (
|
|||
certs: map[string]bool{},
|
||||
}
|
||||
|
||||
sftpAuthError = newAuthenticationError(nil)
|
||||
sftpAuthError = newAuthenticationError(nil, "")
|
||||
)
|
||||
|
||||
// Binding defines the configuration for a network listener
|
||||
|
@ -210,7 +210,8 @@ type Configuration struct {
|
|||
}
|
||||
|
||||
type authenticationError struct {
|
||||
err error
|
||||
err error
|
||||
loginMethod string
|
||||
}
|
||||
|
||||
func (e *authenticationError) Error() string {
|
||||
|
@ -228,8 +229,12 @@ func (e *authenticationError) Unwrap() error {
|
|||
return e.err
|
||||
}
|
||||
|
||||
func newAuthenticationError(err error) *authenticationError {
|
||||
return &authenticationError{err: err}
|
||||
func (e *authenticationError) getLoginMethod() string {
|
||||
return e.loginMethod
|
||||
}
|
||||
|
||||
func newAuthenticationError(err error, loginMethod string) *authenticationError {
|
||||
return &authenticationError{err: err, loginMethod: loginMethod}
|
||||
}
|
||||
|
||||
// ShouldBind returns true if there is at least a valid binding
|
||||
|
@ -253,7 +258,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
|||
return sp, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err))
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
|
||||
dataprovider.SSHLoginMethodPublicKey)
|
||||
}
|
||||
|
||||
return sp, nil
|
||||
|
@ -273,7 +279,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
|||
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||
sp, err := c.validatePasswordCredentials(conn, pass)
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err))
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
|
||||
dataprovider.SSHLoginMethodPassword)
|
||||
}
|
||||
|
||||
return sp, nil
|
||||
|
@ -487,7 +494,8 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
|
|||
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err))
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
|
||||
dataprovider.SSHLoginMethodKeyboardInteractive)
|
||||
}
|
||||
|
||||
return sp, nil
|
||||
|
@ -561,7 +569,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
|||
}
|
||||
|
||||
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||
"User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType,
|
||||
"User %q logged in with %q, from ip %q, client version %q", user.Username, loginType,
|
||||
ipAddr, string(sconn.ClientVersion()))
|
||||
dataprovider.UpdateLastLogin(&user)
|
||||
|
||||
|
@ -683,16 +691,20 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
|
|||
|
||||
func checkAuthError(ip string, err error) {
|
||||
if authErrors, ok := err.(*ssh.ServerAuthError); ok {
|
||||
event := common.HostEventLoginFailed
|
||||
// check public key auth errors here
|
||||
for _, err := range authErrors.Errors {
|
||||
if errors.Is(err, sftpAuthError) {
|
||||
if errors.Is(err, util.ErrNotFound) {
|
||||
event = common.HostEventUserNotFound
|
||||
var sftpAuthErr *authenticationError
|
||||
if errors.As(err, &sftpAuthErr) {
|
||||
if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey {
|
||||
event := common.HostEventLoginFailed
|
||||
if errors.Is(err, util.ErrNotFound) {
|
||||
event = common.HostEventUserNotFound
|
||||
}
|
||||
common.AddDefenderEvent(ip, event)
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
common.AddDefenderEvent(ip, event)
|
||||
} else {
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
|
||||
metric.AddNoAuthTryed()
|
||||
|
@ -1078,7 +1090,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
|
|||
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
|
||||
}
|
||||
if user.IsPartialAuth(method) {
|
||||
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
|
||||
logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
|
||||
return certPerm, ssh.ErrPartialSuccess
|
||||
}
|
||||
sshPerm, err = loginUser(&user, method, keyID, conn)
|
||||
|
|
|
@ -966,6 +966,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
cfg.DefenderConfig.ScoreValid = 1
|
||||
|
||||
err := common.Initialize(cfg, 0)
|
||||
assert.NoError(t, err)
|
||||
|
@ -977,12 +978,23 @@ func TestDefender(t *testing.T) {
|
|||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
|
||||
err = checkBasicSFTP(client)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
user.Password = "wrong_pwd"
|
||||
user.Password = "wrong_pwd"
|
||||
_, _, err = getSftpClient(user, usePubKey)
|
||||
assert.Error(t, err)
|
||||
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 1, host.Score)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, _, err = getSftpClient(user, usePubKey)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -1058,6 +1058,7 @@ func TestDefender(t *testing.T) {
|
|||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
cfg.DefenderConfig.ScoreValid = 1
|
||||
|
||||
err := common.Initialize(cfg, 0)
|
||||
assert.NoError(t, err)
|
||||
|
@ -1067,8 +1068,18 @@ func TestDefender(t *testing.T) {
|
|||
client := getWebDavClient(user, true, nil)
|
||||
assert.NoError(t, checkBasicFunc(client))
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
user.Password = "wrong_pwd"
|
||||
user.Password = "wrong_pwd"
|
||||
client = getWebDavClient(user, false, nil)
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 1, host.Score)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
client = getWebDavClient(user, false, nil)
|
||||
assert.Error(t, checkBasicFunc(client))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue