diff --git a/internal/ftpd/ftpd_test.go b/internal/ftpd/ftpd_test.go index fadc377e..5fb29ba5 100644 --- a/internal/ftpd/ftpd_test.go +++ b/internal/ftpd/ftpd_test.go @@ -1749,8 +1749,10 @@ func TestDefender(t *testing.T) { cfg := config.GetCommonConfig() cfg.DefenderConfig.Enabled = true - cfg.DefenderConfig.Threshold = 3 + cfg.DefenderConfig.Threshold = 4 cfg.DefenderConfig.ScoreLimitExceeded = 2 + cfg.DefenderConfig.ScoreNoAuth = 1 + cfg.DefenderConfig.ScoreValid = 1 err := common.Initialize(cfg, 0) assert.NoError(t, err) @@ -1764,9 +1766,31 @@ func TestDefender(t *testing.T) { err = client.Quit() assert.NoError(t, err) } + // just dial without login + ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} + client, err = ftp.Dial(ftpServerAddr, ftpOptions...) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + host := hosts[0] + assert.Empty(t, host.GetBanTime()) + assert.Equal(t, 1, host.Score) + } + user.Password = "wrong_pwd" + _, err = getFTPClient(user, false, nil) + assert.Error(t, err) + hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + host := hosts[0] + assert.Empty(t, host.GetBanTime()) + assert.Equal(t, 2, host.Score) + } - for i := 0; i < 3; i++ { - user.Password = "wrong_pwd" + for i := 0; i < 2; i++ { _, err = getFTPClient(user, false, nil) assert.Error(t, err) } diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index a495485c..ab7b48fb 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -2229,19 +2229,24 @@ func TestCanReadSymlink(t *testing.T) { } func TestAuthenticationErrors(t *testing.T) { - err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found"))) + loginMethod := dataprovider.SSHLoginMethodPassword + err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod) assert.ErrorIs(t, err, sftpAuthError) assert.ErrorIs(t, err, util.ErrNotFound) - err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission)) + var sftpAuthErr *authenticationError + if assert.ErrorAs(t, err, &sftpAuthErr) { + assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod()) + } + err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod) assert.ErrorIs(t, err, sftpAuthError) assert.NotErrorIs(t, err, util.ErrNotFound) - err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert)) + err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod) assert.ErrorIs(t, err, sftpAuthError) assert.NotErrorIs(t, err, util.ErrNotFound) - err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority")) + err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod) assert.ErrorIs(t, err, sftpAuthError) assert.NotErrorIs(t, err, util.ErrNotFound) - err = newAuthenticationError(nil) + err = newAuthenticationError(nil, loginMethod) assert.ErrorIs(t, err, sftpAuthError) assert.NotErrorIs(t, err, util.ErrNotFound) } diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index e7cd75ab..ca150bc8 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -93,7 +93,7 @@ var ( certs: map[string]bool{}, } - sftpAuthError = newAuthenticationError(nil) + sftpAuthError = newAuthenticationError(nil, "") ) // Binding defines the configuration for a network listener @@ -210,7 +210,8 @@ type Configuration struct { } type authenticationError struct { - err error + err error + loginMethod string } func (e *authenticationError) Error() string { @@ -228,8 +229,12 @@ func (e *authenticationError) Unwrap() error { return e.err } -func newAuthenticationError(err error) *authenticationError { - return &authenticationError{err: err} +func (e *authenticationError) getLoginMethod() string { + return e.loginMethod +} + +func newAuthenticationError(err error, loginMethod string) *authenticationError { + return &authenticationError{err: err, loginMethod: loginMethod} } // ShouldBind returns true if there is at least a valid binding @@ -253,7 +258,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig { return sp, err } if err != nil { - return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err)) + return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err), + dataprovider.SSHLoginMethodPublicKey) } return sp, nil @@ -273,7 +279,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig { serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { sp, err := c.validatePasswordCredentials(conn, pass) if err != nil { - return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err)) + return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), + dataprovider.SSHLoginMethodPassword) } return sp, nil @@ -487,7 +494,8 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { sp, err := c.validateKeyboardInteractiveCredentials(conn, client) if err != nil { - return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err)) + return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), + dataprovider.SSHLoginMethodKeyboardInteractive) } return sp, nil @@ -561,7 +569,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve } logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, - "User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType, + "User %q logged in with %q, from ip %q, client version %q", user.Username, loginType, ipAddr, string(sconn.ClientVersion())) dataprovider.UpdateLastLogin(&user) @@ -683,16 +691,20 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers { func checkAuthError(ip string, err error) { if authErrors, ok := err.(*ssh.ServerAuthError); ok { - event := common.HostEventLoginFailed + // check public key auth errors here for _, err := range authErrors.Errors { - if errors.Is(err, sftpAuthError) { - if errors.Is(err, util.ErrNotFound) { - event = common.HostEventUserNotFound + var sftpAuthErr *authenticationError + if errors.As(err, &sftpAuthErr) { + if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey { + event := common.HostEventLoginFailed + if errors.Is(err, util.ErrNotFound) { + event = common.HostEventUserNotFound + } + common.AddDefenderEvent(ip, event) + return } - break } } - common.AddDefenderEvent(ip, event) } else { logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error()) metric.AddNoAuthTryed() @@ -1078,7 +1090,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey)) } if user.IsPartialAuth(method) { - logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User()) + logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User()) return certPerm, ssh.ErrPartialSuccess } sshPerm, err = loginUser(&user, method, keyID, conn) diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 122c66ff..625d1d37 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -966,6 +966,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 + cfg.DefenderConfig.ScoreValid = 1 err := common.Initialize(cfg, 0) assert.NoError(t, err) @@ -977,12 +978,23 @@ func TestDefender(t *testing.T) { if assert.NoError(t, err) { defer conn.Close() defer client.Close() + err = checkBasicSFTP(client) assert.NoError(t, err) } - for i := 0; i < 3; i++ { - user.Password = "wrong_pwd" + user.Password = "wrong_pwd" + _, _, err = getSftpClient(user, usePubKey) + assert.Error(t, err) + hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + host := hosts[0] + assert.Empty(t, host.GetBanTime()) + assert.Equal(t, 1, host.Score) + } + + for i := 0; i < 2; i++ { _, _, err = getSftpClient(user, usePubKey) assert.Error(t, err) } diff --git a/internal/webdavd/webdavd_test.go b/internal/webdavd/webdavd_test.go index 435d4a05..c1d3be3f 100644 --- a/internal/webdavd/webdavd_test.go +++ b/internal/webdavd/webdavd_test.go @@ -1058,6 +1058,7 @@ func TestDefender(t *testing.T) { cfg.DefenderConfig.Enabled = true cfg.DefenderConfig.Threshold = 3 cfg.DefenderConfig.ScoreLimitExceeded = 2 + cfg.DefenderConfig.ScoreValid = 1 err := common.Initialize(cfg, 0) assert.NoError(t, err) @@ -1067,8 +1068,18 @@ func TestDefender(t *testing.T) { client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) - for i := 0; i < 3; i++ { - user.Password = "wrong_pwd" + user.Password = "wrong_pwd" + client = getWebDavClient(user, false, nil) + assert.Error(t, checkBasicFunc(client)) + hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + host := hosts[0] + assert.Empty(t, host.GetBanTime()) + assert.Equal(t, 1, host.Score) + } + + for i := 0; i < 2; i++ { client = getWebDavClient(user, false, nil) assert.Error(t, checkBasicFunc(client)) }