浏览代码

defender: fix GetHost for blocklist entries too

Nicola Murino 4 年之前
父节点
当前提交
e09bdd43d4
共有 2 个文件被更改,包括 12 次插入10 次删除
  1. 6 4
      common/defender.go
  2. 6 6
      common/defender_test.go

+ 6 - 4
common/defender.go

@@ -266,10 +266,12 @@ func (d *memoryDefender) GetHost(ip string) (*DefenderEntry, error) {
 	defer d.RUnlock()
 
 	if banTime, ok := d.banned[ip]; ok {
-		return &DefenderEntry{
-			IP:      ip,
-			BanTime: banTime,
-		}, nil
+		if banTime.After(time.Now()) {
+			return &DefenderEntry{
+				IP:      ip,
+				BanTime: banTime,
+			}, nil
+		}
 	}
 
 	if hs, ok := d.hosts[ip]; ok {

+ 6 - 6
common/defender_test.go

@@ -206,14 +206,14 @@ func TestExpiredHostBans(t *testing.T) {
 	assert.Len(t, res, 0)
 
 	assert.False(t, defender.IsBanned(testIP))
-	entry, err := defender.GetHost(testIP)
-	assert.NoError(t, err)
-	assert.Equal(t, testIP, entry.IP)
-	assert.NotEmpty(t, entry.GetBanTime())
+	_, err = defender.GetHost(testIP)
+	assert.Error(t, err)
+	_, ok := defender.banned[testIP]
+	assert.True(t, ok)
 	// now add an event for an expired banned ip, it should be removed
 	defender.AddEvent(testIP, HostEventLoginFailed)
 	assert.False(t, defender.IsBanned(testIP))
-	entry, err = defender.GetHost(testIP)
+	entry, err := defender.GetHost(testIP)
 	assert.NoError(t, err)
 	assert.Equal(t, testIP, entry.IP)
 	assert.Empty(t, entry.GetBanTime())
@@ -248,7 +248,7 @@ func TestExpiredHostBans(t *testing.T) {
 	assert.Len(t, res, 0)
 	_, err = defender.GetHost(testIP)
 	assert.Error(t, err)
-	_, ok := defender.hosts[testIP]
+	_, ok = defender.hosts[testIP]
 	assert.True(t, ok)
 }