diff --git a/common/defenderdb_test.go b/common/defenderdb_test.go index b78979e2..0ae894ba 100644 --- a/common/defenderdb_test.go +++ b/common/defenderdb_test.go @@ -197,9 +197,10 @@ func TestBasicDbDefender(t *testing.T) { assert.Equal(t, 1, host.Score) assert.True(t, host.BanTime.IsZero()) assert.Empty(t, host.GetBanTime()) - // cleanup db - err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) - assert.NoError(t, err) + // set a negative observation time so the from field in the queries will be in the future + // we still should get the banned hosts + defender.config.ObservationTime = -2 + assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli()) hosts, err = defender.GetHosts() assert.NoError(t, err) if assert.Len(t, hosts, 1) { @@ -208,6 +209,22 @@ func TestBasicDbDefender(t *testing.T) { assert.False(t, hosts[0].BanTime.IsZero()) assert.NotEmpty(t, hosts[0].GetBanTime()) } + host, err = defender.GetHost(testIP) + assert.NoError(t, err) + // cleanup db + err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) + assert.NoError(t, err) + // the banned host must still be there + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, testIP, hosts[0].IP) + assert.Equal(t, 0, hosts[0].Score) + assert.False(t, hosts[0].BanTime.IsZero()) + assert.NotEmpty(t, hosts[0].GetBanTime()) + } + host, err = defender.GetHost(testIP) + assert.NoError(t, err) err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute))) assert.NoError(t, err) err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 133c1a89..b5889ee3 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -46,12 +46,12 @@ func getAddDefenderEventQuery() string { } func getDefenderHostsQuery() string { - return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v ORDER BY updated_at DESC LIMIT %v`, + return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v OR ban_time > 0 ORDER BY updated_at DESC LIMIT %v`, sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) } func getDefenderHostQuery() string { - return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND updated_at >= %v`, + return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND (updated_at >= %v OR ban_time > 0)`, sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1]) }