mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
sftpd: fix duplicate defender error introduced in the previous commit
improve the defender test cases by verifying the expected score Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
c0fe08b597
commit
27c4ffd663
5 changed files with 91 additions and 27 deletions
|
@ -1749,8 +1749,10 @@ func TestDefender(t *testing.T) {
|
||||||
|
|
||||||
cfg := config.GetCommonConfig()
|
cfg := config.GetCommonConfig()
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 4
|
||||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
cfg.DefenderConfig.ScoreNoAuth = 1
|
||||||
|
cfg.DefenderConfig.ScoreValid = 1
|
||||||
|
|
||||||
err := common.Initialize(cfg, 0)
|
err := common.Initialize(cfg, 0)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -1764,9 +1766,31 @@ func TestDefender(t *testing.T) {
|
||||||
err = client.Quit()
|
err = client.Quit()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
// just dial without login
|
||||||
|
ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
|
||||||
|
client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = client.Quit()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.Len(t, hosts, 1) {
|
||||||
|
host := hosts[0]
|
||||||
|
assert.Empty(t, host.GetBanTime())
|
||||||
|
assert.Equal(t, 1, host.Score)
|
||||||
|
}
|
||||||
|
user.Password = "wrong_pwd"
|
||||||
|
_, err = getFTPClient(user, false, nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.Len(t, hosts, 1) {
|
||||||
|
host := hosts[0]
|
||||||
|
assert.Empty(t, host.GetBanTime())
|
||||||
|
assert.Equal(t, 2, host.Score)
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
user.Password = "wrong_pwd"
|
|
||||||
_, err = getFTPClient(user, false, nil)
|
_, err = getFTPClient(user, false, nil)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2229,19 +2229,24 @@ func TestCanReadSymlink(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthenticationErrors(t *testing.T) {
|
func TestAuthenticationErrors(t *testing.T) {
|
||||||
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")))
|
loginMethod := dataprovider.SSHLoginMethodPassword
|
||||||
|
err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod)
|
||||||
assert.ErrorIs(t, err, sftpAuthError)
|
assert.ErrorIs(t, err, sftpAuthError)
|
||||||
assert.ErrorIs(t, err, util.ErrNotFound)
|
assert.ErrorIs(t, err, util.ErrNotFound)
|
||||||
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission))
|
var sftpAuthErr *authenticationError
|
||||||
|
if assert.ErrorAs(t, err, &sftpAuthErr) {
|
||||||
|
assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
|
||||||
|
}
|
||||||
|
err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod)
|
||||||
assert.ErrorIs(t, err, sftpAuthError)
|
assert.ErrorIs(t, err, sftpAuthError)
|
||||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||||
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert))
|
err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod)
|
||||||
assert.ErrorIs(t, err, sftpAuthError)
|
assert.ErrorIs(t, err, sftpAuthError)
|
||||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||||
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"))
|
err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod)
|
||||||
assert.ErrorIs(t, err, sftpAuthError)
|
assert.ErrorIs(t, err, sftpAuthError)
|
||||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||||
err = newAuthenticationError(nil)
|
err = newAuthenticationError(nil, loginMethod)
|
||||||
assert.ErrorIs(t, err, sftpAuthError)
|
assert.ErrorIs(t, err, sftpAuthError)
|
||||||
assert.NotErrorIs(t, err, util.ErrNotFound)
|
assert.NotErrorIs(t, err, util.ErrNotFound)
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,7 +93,7 @@ var (
|
||||||
certs: map[string]bool{},
|
certs: map[string]bool{},
|
||||||
}
|
}
|
||||||
|
|
||||||
sftpAuthError = newAuthenticationError(nil)
|
sftpAuthError = newAuthenticationError(nil, "")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Binding defines the configuration for a network listener
|
// Binding defines the configuration for a network listener
|
||||||
|
@ -210,7 +210,8 @@ type Configuration struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type authenticationError struct {
|
type authenticationError struct {
|
||||||
err error
|
err error
|
||||||
|
loginMethod string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *authenticationError) Error() string {
|
func (e *authenticationError) Error() string {
|
||||||
|
@ -228,8 +229,12 @@ func (e *authenticationError) Unwrap() error {
|
||||||
return e.err
|
return e.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthenticationError(err error) *authenticationError {
|
func (e *authenticationError) getLoginMethod() string {
|
||||||
return &authenticationError{err: err}
|
return e.loginMethod
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthenticationError(err error, loginMethod string) *authenticationError {
|
||||||
|
return &authenticationError{err: err, loginMethod: loginMethod}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShouldBind returns true if there is at least a valid binding
|
// ShouldBind returns true if there is at least a valid binding
|
||||||
|
@ -253,7 +258,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
||||||
return sp, err
|
return sp, err
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err))
|
return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
|
||||||
|
dataprovider.SSHLoginMethodPublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sp, nil
|
return sp, nil
|
||||||
|
@ -273,7 +279,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
||||||
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||||
sp, err := c.validatePasswordCredentials(conn, pass)
|
sp, err := c.validatePasswordCredentials(conn, pass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err))
|
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
|
||||||
|
dataprovider.SSHLoginMethodPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sp, nil
|
return sp, nil
|
||||||
|
@ -487,7 +494,8 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
|
||||||
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||||
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
|
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err))
|
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
|
||||||
|
dataprovider.SSHLoginMethodKeyboardInteractive)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sp, nil
|
return sp, nil
|
||||||
|
@ -561,7 +569,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||||
"User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType,
|
"User %q logged in with %q, from ip %q, client version %q", user.Username, loginType,
|
||||||
ipAddr, string(sconn.ClientVersion()))
|
ipAddr, string(sconn.ClientVersion()))
|
||||||
dataprovider.UpdateLastLogin(&user)
|
dataprovider.UpdateLastLogin(&user)
|
||||||
|
|
||||||
|
@ -683,16 +691,20 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
|
||||||
|
|
||||||
func checkAuthError(ip string, err error) {
|
func checkAuthError(ip string, err error) {
|
||||||
if authErrors, ok := err.(*ssh.ServerAuthError); ok {
|
if authErrors, ok := err.(*ssh.ServerAuthError); ok {
|
||||||
event := common.HostEventLoginFailed
|
// check public key auth errors here
|
||||||
for _, err := range authErrors.Errors {
|
for _, err := range authErrors.Errors {
|
||||||
if errors.Is(err, sftpAuthError) {
|
var sftpAuthErr *authenticationError
|
||||||
if errors.Is(err, util.ErrNotFound) {
|
if errors.As(err, &sftpAuthErr) {
|
||||||
event = common.HostEventUserNotFound
|
if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey {
|
||||||
|
event := common.HostEventLoginFailed
|
||||||
|
if errors.Is(err, util.ErrNotFound) {
|
||||||
|
event = common.HostEventUserNotFound
|
||||||
|
}
|
||||||
|
common.AddDefenderEvent(ip, event)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common.AddDefenderEvent(ip, event)
|
|
||||||
} else {
|
} else {
|
||||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
|
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
|
||||||
metric.AddNoAuthTryed()
|
metric.AddNoAuthTryed()
|
||||||
|
@ -1078,7 +1090,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
|
||||||
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
|
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
|
||||||
}
|
}
|
||||||
if user.IsPartialAuth(method) {
|
if user.IsPartialAuth(method) {
|
||||||
logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
|
logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
|
||||||
return certPerm, ssh.ErrPartialSuccess
|
return certPerm, ssh.ErrPartialSuccess
|
||||||
}
|
}
|
||||||
sshPerm, err = loginUser(&user, method, keyID, conn)
|
sshPerm, err = loginUser(&user, method, keyID, conn)
|
||||||
|
|
|
@ -966,6 +966,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
cfg.DefenderConfig.ScoreValid = 1
|
||||||
|
|
||||||
err := common.Initialize(cfg, 0)
|
err := common.Initialize(cfg, 0)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -977,12 +978,23 @@ func TestDefender(t *testing.T) {
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
err = checkBasicSFTP(client)
|
err = checkBasicSFTP(client)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
user.Password = "wrong_pwd"
|
||||||
user.Password = "wrong_pwd"
|
_, _, err = getSftpClient(user, usePubKey)
|
||||||
|
assert.Error(t, err)
|
||||||
|
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.Len(t, hosts, 1) {
|
||||||
|
host := hosts[0]
|
||||||
|
assert.Empty(t, host.GetBanTime())
|
||||||
|
assert.Equal(t, 1, host.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
_, _, err = getSftpClient(user, usePubKey)
|
_, _, err = getSftpClient(user, usePubKey)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1058,6 +1058,7 @@ func TestDefender(t *testing.T) {
|
||||||
cfg.DefenderConfig.Enabled = true
|
cfg.DefenderConfig.Enabled = true
|
||||||
cfg.DefenderConfig.Threshold = 3
|
cfg.DefenderConfig.Threshold = 3
|
||||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||||
|
cfg.DefenderConfig.ScoreValid = 1
|
||||||
|
|
||||||
err := common.Initialize(cfg, 0)
|
err := common.Initialize(cfg, 0)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
@ -1067,8 +1068,18 @@ func TestDefender(t *testing.T) {
|
||||||
client := getWebDavClient(user, true, nil)
|
client := getWebDavClient(user, true, nil)
|
||||||
assert.NoError(t, checkBasicFunc(client))
|
assert.NoError(t, checkBasicFunc(client))
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
user.Password = "wrong_pwd"
|
||||||
user.Password = "wrong_pwd"
|
client = getWebDavClient(user, false, nil)
|
||||||
|
assert.Error(t, checkBasicFunc(client))
|
||||||
|
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.Len(t, hosts, 1) {
|
||||||
|
host := hosts[0]
|
||||||
|
assert.Empty(t, host.GetBanTime())
|
||||||
|
assert.Equal(t, 1, host.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
client = getWebDavClient(user, false, nil)
|
client = getWebDavClient(user, false, nil)
|
||||||
assert.Error(t, checkBasicFunc(client))
|
assert.Error(t, checkBasicFunc(client))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue