فهرست منبع

defender: don't return expired hosts/banned ip

Nicola Murino 4 سال پیش
والد
کامیت
f2f612b450
2فایلهای تغییر یافته به همراه98 افزوده شده و 10 حذف شده
  1. 23 10
      common/defender.go
  2. 75 0
      common/defender_test.go

+ 23 - 10
common/defender.go

@@ -236,16 +236,26 @@ func (d *memoryDefender) GetHosts() []*DefenderEntry {
 
 	var result []*DefenderEntry
 	for k, v := range d.banned {
-		result = append(result, &DefenderEntry{
-			IP:      k,
-			BanTime: v,
-		})
+		if v.After(time.Now()) {
+			result = append(result, &DefenderEntry{
+				IP:      k,
+				BanTime: v,
+			})
+		}
 	}
 	for k, v := range d.hosts {
-		result = append(result, &DefenderEntry{
-			IP:    k,
-			Score: v.TotalScore,
-		})
+		score := 0
+		for _, event := range v.Events {
+			if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
+				score += event.score
+			}
+		}
+		if score > 0 {
+			result = append(result, &DefenderEntry{
+				IP:    k,
+				Score: score,
+			})
+		}
 	}
 
 	return result
@@ -339,8 +349,11 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
 	}
 
 	// ignore events for already banned hosts
-	if _, ok := d.banned[ip]; ok {
-		return
+	if v, ok := d.banned[ip]; ok {
+		if v.After(time.Now()) {
+			return
+		}
+		delete(d.banned, ip)
 	}
 
 	var score int

+ 75 - 0
common/defender_test.go

@@ -179,6 +179,81 @@ func TestBasicDefender(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestExpiredHostBans(t *testing.T) {
+	config := &DefenderConfig{
+		Enabled:            true,
+		BanTime:            10,
+		BanTimeIncrement:   2,
+		Threshold:          5,
+		ScoreInvalid:       2,
+		ScoreValid:         1,
+		ScoreLimitExceeded: 3,
+		ObservationTime:    15,
+		EntriesSoftLimit:   1,
+		EntriesHardLimit:   2,
+	}
+
+	d, err := newInMemoryDefender(config)
+	assert.NoError(t, err)
+
+	defender := d.(*memoryDefender)
+
+	testIP := "1.2.3.4"
+	defender.banned[testIP] = time.Now().Add(-24 * time.Hour)
+
+	// the ban is expired testIP should not be listed
+	res := defender.GetHosts()
+	assert.Len(t, res, 0)
+
+	assert.False(t, defender.IsBanned(testIP))
+	entry, err := defender.GetHost(testIP)
+	assert.NoError(t, err)
+	assert.Equal(t, testIP, entry.IP)
+	assert.NotEmpty(t, entry.GetBanTime())
+	// now add an event for an expired banned ip, it should be removed
+	defender.AddEvent(testIP, HostEventLoginFailed)
+	assert.False(t, defender.IsBanned(testIP))
+	entry, err = defender.GetHost(testIP)
+	assert.NoError(t, err)
+	assert.Equal(t, testIP, entry.IP)
+	assert.Empty(t, entry.GetBanTime())
+	assert.Equal(t, 1, entry.Score)
+
+	res = defender.GetHosts()
+	if assert.Len(t, res, 1) {
+		assert.Equal(t, testIP, res[0].IP)
+		assert.Empty(t, res[0].GetBanTime())
+		assert.Equal(t, 1, res[0].Score)
+	}
+
+	events := []hostEvent{
+		{
+			dateTime: time.Now().Add(-24 * time.Hour),
+			score:    2,
+		},
+		{
+			dateTime: time.Now().Add(-24 * time.Hour),
+			score:    3,
+		},
+	}
+
+	hs := hostScore{
+		Events:     events,
+		TotalScore: 5,
+	}
+
+	defender.hosts[testIP] = hs
+	// the recorded scored are too old
+	res = defender.GetHosts()
+	assert.Len(t, res, 0)
+	// the old API still returns the host
+	entry, err = defender.GetHost(testIP)
+	assert.NoError(t, err)
+	assert.Equal(t, testIP, entry.IP)
+	assert.Empty(t, entry.GetBanTime())
+	assert.Equal(t, 5, entry.Score)
+}
+
 func TestLoadHostListFromFile(t *testing.T) {
 	_, err := loadHostListFromFile(".")
 	assert.Error(t, err)