sftpgo-mirror/common/defendermem.go

327 lines
6.9 KiB
Go
Raw Normal View History

package common
import (
"sort"
"time"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
)
// memoryDefender keeps the defender state (observed hosts and active bans)
// in process memory.
type memoryDefender struct {
	baseDefender
	// hosts stores the IP addresses of the clients for which violations
	// have been recorded; the key is the host IP. Once the threshold is
	// reached the host is added to banned.
	hosts map[string]hostScore
	// banned stores the banned hosts; the key is the host IP.
	// A violation from a banned host will increase the ban time
	// based on the configured BanTimeIncrement.
	banned map[string]time.Time
}
// newInMemoryDefender validates the given configuration and returns a
// Defender backed by in-memory maps. The defender is reloaded once before
// being returned; a validation or reload error is returned as is.
func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
	if validationErr := config.validate(); validationErr != nil {
		return nil, validationErr
	}

	d := &memoryDefender{
		baseDefender: baseDefender{config: config},
		hosts:        map[string]hostScore{},
		banned:       map[string]time.Time{},
	}

	reloadErr := d.Reload()
	if reloadErr != nil {
		return nil, reloadErr
	}

	return d, nil
}
// GetHosts returns the hosts with a ban that has not expired yet and the
// hosts whose events inside the observation time sum up to a positive score.
// The returned error is always nil.
func (d *memoryDefender) GetHosts() ([]dataprovider.DefenderEntry, error) {
	d.RLock()
	defer d.RUnlock()

	// inWindow reports whether the event is still inside the observation time
	inWindow := func(ev hostEvent) bool {
		return time.Now().Before(ev.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute))
	}

	var entries []dataprovider.DefenderEntry

	for ip, banTime := range d.banned {
		if !banTime.After(time.Now()) {
			// expired ban, skip it
			continue
		}
		entries = append(entries, dataprovider.DefenderEntry{
			IP:      ip,
			BanTime: banTime,
		})
	}

	for ip, hs := range d.hosts {
		total := 0
		for _, ev := range hs.Events {
			if inWindow(ev) {
				total += ev.score
			}
		}
		if total <= 0 {
			continue
		}
		entries = append(entries, dataprovider.DefenderEntry{
			IP:    ip,
			Score: total,
		})
	}

	return entries, nil
}
// GetHost returns the defender entry for the given IP. An active ban takes
// precedence; otherwise the entry carries the score of the events still
// inside the observation time. A record not found error is returned if the
// IP has neither an active ban nor a positive score.
func (d *memoryDefender) GetHost(ip string) (dataprovider.DefenderEntry, error) {
	d.RLock()
	defer d.RUnlock()

	banTime, isBanned := d.banned[ip]
	if isBanned && banTime.After(time.Now()) {
		entry := dataprovider.DefenderEntry{
			IP:      ip,
			BanTime: banTime,
		}
		return entry, nil
	}

	if hs, isTracked := d.hosts[ip]; isTracked {
		total := 0
		for _, ev := range hs.Events {
			expiry := ev.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute)
			if expiry.After(time.Now()) {
				total += ev.score
			}
		}
		if total > 0 {
			entry := dataprovider.DefenderEntry{
				IP:    ip,
				Score: total,
			}
			return entry, nil
		}
	}

	return dataprovider.DefenderEntry{}, util.NewRecordNotFoundError("host not found")
}
// IsBanned returns true if the specified IP is banned
// and increases the ban time if the IP is found.
// The ban is extended by BanTime * BanTimeIncrement / 100 minutes, with a
// minimum of one minute. If the IP has no active ban here the result of the
// baseDefender check is returned.
// This method must be called as soon as the client connects
func (d *memoryDefender) IsBanned(ip string) bool {
	d.RLock()

	if banTime, ok := d.banned[ip]; ok && banTime.After(time.Now()) {
		increment := d.config.BanTime * d.config.BanTimeIncrement / 100
		if increment == 0 {
			increment++
		}

		d.RUnlock()

		// The read lock is held as long as possible for performance reasons:
		// this method is called each time a new client connects and it must
		// be as fast as possible. Since the lock cannot be upgraded, the
		// entry is read again under the write lock: this way a ban removed
		// in the meantime is not resurrected and concurrent increments are
		// applied to the current value instead of overwriting each other.
		d.Lock()
		if current, stillBanned := d.banned[ip]; stillBanned {
			d.banned[ip] = current.Add(time.Duration(increment) * time.Minute)
		}
		d.Unlock()

		// the host was banned when it was checked
		return true
	}

	defer d.RUnlock()

	return d.baseDefender.isBanned(ip)
}
// DeleteHost removes the specified IP from the defender lists.
// The banned list is checked first and only one list is modified per call.
// It returns false if the IP is in neither list.
func (d *memoryDefender) DeleteHost(ip string) bool {
	d.Lock()
	defer d.Unlock()

	_, inBanned := d.banned[ip]
	_, inHosts := d.hosts[ip]

	switch {
	case inBanned:
		delete(d.banned, ip)
	case inHosts:
		delete(d.hosts, ip)
	default:
		return false
	}

	return true
}
// AddEvent records an event for the given IP.
// Events for safe listed IPs and for hosts with an active ban are ignored,
// an expired ban is removed. If the score of the events inside the
// observation time reaches the configured threshold the host is banned
// for BanTime minutes.
// This method must be called for clients not yet banned
func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
	d.Lock()
	defer d.Unlock()

	if d.safeList != nil && d.safeList.isListed(ip) {
		return
	}

	// ignore events for already banned hosts
	if banTime, isBanned := d.banned[ip]; isBanned {
		if banTime.After(time.Now()) {
			return
		}
		delete(d.banned, ip)
	}

	points := d.baseDefender.getScore(event)
	newEvent := hostEvent{
		dateTime: time.Now(),
		score:    points,
	}

	hs, isTracked := d.hosts[ip]
	if !isTracked {
		// first event for this host
		d.hosts[ip] = hostScore{
			TotalScore: newEvent.score,
			Events:     []hostEvent{newEvent},
		}
		d.cleanupHosts()
		return
	}

	hs.Events = append(hs.Events, newEvent)
	hs.TotalScore = 0

	// keep only the events inside the observation time, reusing the
	// backing array, and recompute the total score from them
	recent := hs.Events[:0]
	for _, ev := range hs.Events {
		expiry := ev.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute)
		if time.Now().Before(expiry) {
			recent = append(recent, ev)
			hs.TotalScore += ev.score
		}
	}
	hs.Events = recent

	if hs.TotalScore < d.config.Threshold {
		d.hosts[ip] = hs
		return
	}

	d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
	delete(d.hosts, ip)
	d.cleanupBanned()
}
// countBanned returns the number of entries in the banned map,
// expired bans not yet removed are included
func (d *memoryDefender) countBanned() int {
	d.RLock()
	total := len(d.banned)
	d.RUnlock()

	return total
}
// countHosts returns the number of entries in the hosts map
func (d *memoryDefender) countHosts() int {
	d.RLock()
	total := len(d.hosts)
	d.RUnlock()

	return total
}
// GetBanTime returns the ban time stored for the given IP or nil if there
// is no entry for the IP. The stored value is returned without checking
// if it is already expired. The returned error is always nil.
func (d *memoryDefender) GetBanTime(ip string) (*time.Time, error) {
	d.RLock()
	defer d.RUnlock()

	banTime, found := d.banned[ip]
	if !found {
		return nil, nil
	}

	return &banTime, nil
}
// GetScore returns the score for the given IP: the sum of the scores of
// its events still inside the observation time, 0 if the IP is unknown.
// The returned error is always nil.
func (d *memoryDefender) GetScore(ip string) (int, error) {
	d.RLock()
	defer d.RUnlock()

	hs, found := d.hosts[ip]
	if !found {
		return 0, nil
	}

	total := 0
	for idx := range hs.Events {
		ev := hs.Events[idx]
		expiry := ev.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute)
		if time.Now().Before(expiry) {
			total += ev.score
		}
	}

	return total, nil
}
// cleanupBanned shrinks the banned map once it grows above EntriesHardLimit.
// Expired bans are removed first; if the map is still above
// EntriesSoftLimit, the bans with the earliest ban time are removed until
// the soft limit is reached.
// The caller must hold the write lock.
func (d *memoryDefender) cleanupBanned() {
	if len(d.banned) <= d.config.EntriesHardLimit {
		return
	}

	now := time.Now()
	candidates := make(kvList, 0, len(d.banned))

	for ip, banTime := range d.banned {
		if banTime.Before(now) {
			delete(d.banned, ip)
			// an expired entry is already gone, it must not be added to
			// the candidates or it would use one of the removal slots below
			continue
		}
		candidates = append(candidates, kv{
			Key:   ip,
			Value: banTime.UnixNano(),
		})
	}

	// we removed expired ip addresses, if any, above, this could be enough
	numToRemove := len(d.banned) - d.config.EntriesSoftLimit
	if numToRemove <= 0 {
		return
	}

	// earliest ban times first
	sort.Sort(candidates)

	for idx, entry := range candidates {
		if idx >= numToRemove {
			break
		}
		delete(d.banned, entry.Key)
	}
}
// cleanupHosts shrinks the hosts map once it grows above EntriesHardLimit:
// the hosts whose last recorded event is the oldest are removed until
// EntriesSoftLimit is reached. A host without events sorts first.
// The caller must hold the write lock.
func (d *memoryDefender) cleanupHosts() {
	if len(d.hosts) <= d.config.EntriesHardLimit {
		return
	}

	byLastEvent := make(kvList, 0, len(d.hosts))
	for ip, hs := range d.hosts {
		var lastEvent int64
		if n := len(hs.Events); n > 0 {
			lastEvent = hs.Events[n-1].dateTime.UnixNano()
		}
		byLastEvent = append(byLastEvent, kv{Key: ip, Value: lastEvent})
	}

	// oldest last event first
	sort.Sort(byLastEvent)

	numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
	for i := 0; i < len(byLastEvent) && i < numToRemove; i++ {
		delete(d.hosts, byLastEvent[i].Key)
	}
}
// kv pairs a string key with an int64 value used for ordering
type kv struct {
	Key   string
	Value int64
}

// kvList is a list of kv that can be sorted by ascending Value
// using sort.Sort
type kvList []kv

// Len returns the number of elements in the list
func (l kvList) Len() int {
	return len(l)
}

// Less reports whether the element at index i has a smaller Value
// than the element at index j
func (l kvList) Less(i, j int) bool {
	return l[j].Value > l[i].Value
}

// Swap exchanges the elements at indexes i and j
func (l kvList) Swap(i, j int) {
	l[j], l[i] = l[i], l[j]
}