db defender: fix list hosts queries

ensure that banned hosts are always returned

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2022-03-16 18:27:47 +01:00
parent 7959737442
commit 5a45af76f3
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
2 changed files with 22 additions and 5 deletions

View file

@ -197,9 +197,10 @@ func TestBasicDbDefender(t *testing.T) {
assert.Equal(t, 1, host.Score) assert.Equal(t, 1, host.Score)
assert.True(t, host.BanTime.IsZero()) assert.True(t, host.BanTime.IsZero())
assert.Empty(t, host.GetBanTime()) assert.Empty(t, host.GetBanTime())
// cleanup db // set a negative observation time so the from field in the queries will be in the future
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) // we still should get the banned hosts
assert.NoError(t, err) defender.config.ObservationTime = -2
assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli())
hosts, err = defender.GetHosts() hosts, err = defender.GetHosts()
assert.NoError(t, err) assert.NoError(t, err)
if assert.Len(t, hosts, 1) { if assert.Len(t, hosts, 1) {
@ -208,6 +209,22 @@ func TestBasicDbDefender(t *testing.T) {
assert.False(t, hosts[0].BanTime.IsZero()) assert.False(t, hosts[0].BanTime.IsZero())
assert.NotEmpty(t, hosts[0].GetBanTime()) assert.NotEmpty(t, hosts[0].GetBanTime())
} }
host, err = defender.GetHost(testIP)
assert.NoError(t, err)
// cleanup db
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
assert.NoError(t, err)
// the banned host must still be there
hosts, err = defender.GetHosts()
assert.NoError(t, err)
if assert.Len(t, hosts, 1) {
assert.Equal(t, testIP, hosts[0].IP)
assert.Equal(t, 0, hosts[0].Score)
assert.False(t, hosts[0].BanTime.IsZero())
assert.NotEmpty(t, hosts[0].GetBanTime())
}
host, err = defender.GetHost(testIP)
assert.NoError(t, err)
err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute))) err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
assert.NoError(t, err) assert.NoError(t, err)
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))

View file

@ -46,12 +46,12 @@ func getAddDefenderEventQuery() string {
} }
func getDefenderHostsQuery() string { func getDefenderHostsQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v ORDER BY updated_at DESC LIMIT %v`, return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v OR ban_time > 0 ORDER BY updated_at DESC LIMIT %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
} }
func getDefenderHostQuery() string { func getDefenderHostQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND updated_at >= %v`, return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND (updated_at >= %v OR ban_time > 0)`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
} }