diff --git a/common/defender.go b/common/defender.go index c1252664..5b55f170 100644 --- a/common/defender.go +++ b/common/defender.go @@ -267,17 +267,27 @@ func (d *memoryDefender) GetHost(ip string) (*DefenderEntry, error) { defer d.RUnlock() if banTime, ok := d.banned[ip]; ok { - return &DefenderEntry{ - IP: ip, - BanTime: banTime, - }, nil + if banTime.After(time.Now()) { + return &DefenderEntry{ + IP: ip, + BanTime: banTime, + }, nil + } } - if ev, ok := d.hosts[ip]; ok { - return &DefenderEntry{ - IP: ip, - Score: ev.TotalScore, - }, nil + if hs, ok := d.hosts[ip]; ok { + score := 0 + for _, event := range hs.Events { + if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { + score += event.score + } + } + if score > 0 { + return &DefenderEntry{ + IP: ip, + Score: score, + }, nil + } } return nil, dataprovider.NewRecordNotFoundError("host not found") diff --git a/common/defender_test.go b/common/defender_test.go index ba530290..2fab5fe2 100644 --- a/common/defender_test.go +++ b/common/defender_test.go @@ -206,14 +206,14 @@ func TestExpiredHostBans(t *testing.T) { assert.Len(t, res, 0) assert.False(t, defender.IsBanned(testIP)) - entry, err := defender.GetHost(testIP) - assert.NoError(t, err) - assert.Equal(t, testIP, entry.IP) - assert.NotEmpty(t, entry.GetBanTime()) + _, err = defender.GetHost(testIP) + assert.Error(t, err) + _, ok := defender.banned[testIP] + assert.True(t, ok) // now add an event for an expired banned ip, it should be removed defender.AddEvent(testIP, HostEventLoginFailed) assert.False(t, defender.IsBanned(testIP)) - entry, err = defender.GetHost(testIP) + entry, err := defender.GetHost(testIP) assert.NoError(t, err) assert.Equal(t, testIP, entry.IP) assert.Empty(t, entry.GetBanTime()) @@ -246,12 +246,8 @@ func TestExpiredHostBans(t *testing.T) { // the recorded scored are too old res = defender.GetHosts() assert.Len(t, res, 0) - // the old API still returns the host - entry, err = defender.GetHost(testIP) - assert.NoError(t, err) - assert.Equal(t, testIP, entry.IP) - assert.Empty(t, entry.GetBanTime()) - assert.Equal(t, 5, entry.Score) + _, err = defender.GetHost(testIP) + assert.Error(t, err) } func TestLoadHostListFromFile(t *testing.T) {