parent
8174349032
commit
7d8823307f
30 changed files with 2177 additions and 609 deletions
|
@ -139,14 +139,6 @@ func Initialize(c Configuration) error {
|
|||
startIdleTimeoutTicker(idleTimeoutCheckInterval)
|
||||
}
|
||||
Config.defender = nil
|
||||
if c.DefenderConfig.Enabled {
|
||||
defender, err := newInMemoryDefender(&c.DefenderConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("defender initialization error: %v", err)
|
||||
}
|
||||
logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig)
|
||||
Config.defender = defender
|
||||
}
|
||||
rateLimiters = make(map[string][]*rateLimiter)
|
||||
for _, rlCfg := range c.RateLimitersConfig {
|
||||
if rlCfg.isEnabled() {
|
||||
|
@ -164,6 +156,24 @@ func Initialize(c Configuration) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
if c.DefenderConfig.Enabled {
|
||||
if !util.IsStringInSlice(c.DefenderConfig.Driver, supportedDefenderDrivers) {
|
||||
return fmt.Errorf("unsupported defender driver %#v", c.DefenderConfig.Driver)
|
||||
}
|
||||
var defender Defender
|
||||
var err error
|
||||
switch c.DefenderConfig.Driver {
|
||||
case DefenderDriverProvider:
|
||||
defender, err = newDBDefender(&c.DefenderConfig)
|
||||
default:
|
||||
defender, err = newInMemoryDefender(&c.DefenderConfig)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("defender initialization error: %v", err)
|
||||
}
|
||||
logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig)
|
||||
Config.defender = defender
|
||||
}
|
||||
vfs.SetTempPath(c.TempPath)
|
||||
dataprovider.SetTempPath(c.TempPath)
|
||||
return nil
|
||||
|
@ -203,25 +213,25 @@ func IsBanned(ip string) bool {
|
|||
|
||||
// GetDefenderBanTime returns the ban time for the given IP
|
||||
// or nil if the IP is not banned or the defender is disabled
|
||||
func GetDefenderBanTime(ip string) *time.Time {
|
||||
func GetDefenderBanTime(ip string) (*time.Time, error) {
|
||||
if Config.defender == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return Config.defender.GetBanTime(ip)
|
||||
}
|
||||
|
||||
// GetDefenderHosts returns hosts that are banned or for which some violations have been detected
|
||||
func GetDefenderHosts() []*DefenderEntry {
|
||||
func GetDefenderHosts() ([]*dataprovider.DefenderEntry, error) {
|
||||
if Config.defender == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return Config.defender.GetHosts()
|
||||
}
|
||||
|
||||
// GetDefenderHost returns a defender host by ip, if any
|
||||
func GetDefenderHost(ip string) (*DefenderEntry, error) {
|
||||
func GetDefenderHost(ip string) (*dataprovider.DefenderEntry, error) {
|
||||
if Config.defender == nil {
|
||||
return nil, errors.New("defender is disabled")
|
||||
}
|
||||
|
@ -239,9 +249,9 @@ func DeleteDefenderHost(ip string) bool {
|
|||
}
|
||||
|
||||
// GetDefenderScore returns the score for the given IP
|
||||
func GetDefenderScore(ip string) int {
|
||||
func GetDefenderScore(ip string) (int, error) {
|
||||
if Config.defender == nil {
|
||||
return 0
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return Config.defender.GetScore(ip)
|
||||
|
|
|
@ -134,15 +134,22 @@ func TestDefenderIntegration(t *testing.T) {
|
|||
AddDefenderEvent(ip, HostEventNoLoginTried)
|
||||
assert.False(t, IsBanned(ip))
|
||||
|
||||
assert.Nil(t, GetDefenderBanTime(ip))
|
||||
banTime, err := GetDefenderBanTime(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
assert.False(t, DeleteDefenderHost(ip))
|
||||
assert.Equal(t, 0, GetDefenderScore(ip))
|
||||
_, err := GetDefenderHost(ip)
|
||||
score, err := GetDefenderScore(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
_, err = GetDefenderHost(ip)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, GetDefenderHosts())
|
||||
hosts, err := GetDefenderHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, hosts)
|
||||
|
||||
Config.DefenderConfig = DefenderConfig{
|
||||
Enabled: true,
|
||||
Driver: DefenderDriverProvider,
|
||||
BanTime: 10,
|
||||
BanTimeIncrement: 50,
|
||||
Threshold: 0,
|
||||
|
@ -153,6 +160,16 @@ func TestDefenderIntegration(t *testing.T) {
|
|||
EntriesHardLimit: 150,
|
||||
}
|
||||
err = Initialize(Config)
|
||||
// ScoreInvalid cannot be greater than threshold
|
||||
assert.Error(t, err)
|
||||
Config.DefenderConfig.Driver = "unsupported"
|
||||
err = Initialize(Config)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unsupported defender driver")
|
||||
}
|
||||
Config.DefenderConfig.Driver = DefenderDriverMemory
|
||||
err = Initialize(Config)
|
||||
// ScoreInvalid cannot be greater than threshold
|
||||
assert.Error(t, err)
|
||||
Config.DefenderConfig.Threshold = 3
|
||||
err = Initialize(Config)
|
||||
|
@ -161,27 +178,41 @@ func TestDefenderIntegration(t *testing.T) {
|
|||
|
||||
AddDefenderEvent(ip, HostEventNoLoginTried)
|
||||
assert.False(t, IsBanned(ip))
|
||||
assert.Equal(t, 2, GetDefenderScore(ip))
|
||||
score, err = GetDefenderScore(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, score)
|
||||
entry, err := GetDefenderHost(ip)
|
||||
assert.NoError(t, err)
|
||||
asJSON, err := json.Marshal(&entry)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"id":"3132372e312e312e31","ip":"127.1.1.1","score":2}`, string(asJSON), "entry %v", entry)
|
||||
assert.True(t, DeleteDefenderHost(ip))
|
||||
assert.Nil(t, GetDefenderBanTime(ip))
|
||||
banTime, err = GetDefenderBanTime(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
|
||||
AddDefenderEvent(ip, HostEventLoginFailed)
|
||||
AddDefenderEvent(ip, HostEventNoLoginTried)
|
||||
assert.True(t, IsBanned(ip))
|
||||
assert.Equal(t, 0, GetDefenderScore(ip))
|
||||
assert.NotNil(t, GetDefenderBanTime(ip))
|
||||
assert.Len(t, GetDefenderHosts(), 1)
|
||||
score, err = GetDefenderScore(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
banTime, err = GetDefenderBanTime(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, banTime)
|
||||
hosts, err = GetDefenderHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 1)
|
||||
entry, err = GetDefenderHost(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, entry.BanTime.IsZero())
|
||||
assert.True(t, DeleteDefenderHost(ip))
|
||||
assert.Len(t, GetDefenderHosts(), 0)
|
||||
assert.Nil(t, GetDefenderBanTime(ip))
|
||||
hosts, err = GetDefenderHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
banTime, err = GetDefenderBanTime(ip)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
assert.False(t, DeleteDefenderHost(ip))
|
||||
|
||||
Config = configCopy
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yl2chen/cidranger"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/logger"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
@ -27,49 +26,24 @@ const (
|
|||
HostEventLimitExceeded
|
||||
)
|
||||
|
||||
// DefenderEntry defines a defender entry
|
||||
type DefenderEntry struct {
|
||||
IP string `json:"ip"`
|
||||
Score int `json:"score,omitempty"`
|
||||
BanTime time.Time `json:"ban_time,omitempty"`
|
||||
}
|
||||
// Supported defender drivers
|
||||
const (
|
||||
DefenderDriverMemory = "memory"
|
||||
DefenderDriverProvider = "provider"
|
||||
)
|
||||
|
||||
// GetID returns an unique ID for a defender entry
|
||||
func (d *DefenderEntry) GetID() string {
|
||||
return hex.EncodeToString([]byte(d.IP))
|
||||
}
|
||||
|
||||
// GetBanTime returns the ban time for a defender entry as string
|
||||
func (d *DefenderEntry) GetBanTime() string {
|
||||
if d.BanTime.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return d.BanTime.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON encoding of a DefenderEntry.
|
||||
func (d *DefenderEntry) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
ID string `json:"id"`
|
||||
IP string `json:"ip"`
|
||||
Score int `json:"score,omitempty"`
|
||||
BanTime string `json:"ban_time,omitempty"`
|
||||
}{
|
||||
ID: d.GetID(),
|
||||
IP: d.IP,
|
||||
Score: d.Score,
|
||||
BanTime: d.GetBanTime(),
|
||||
})
|
||||
}
|
||||
var (
|
||||
supportedDefenderDrivers = []string{DefenderDriverMemory, DefenderDriverProvider}
|
||||
)
|
||||
|
||||
// Defender defines the interface that a defender must implements
|
||||
type Defender interface {
|
||||
GetHosts() []*DefenderEntry
|
||||
GetHost(ip string) (*DefenderEntry, error)
|
||||
GetHosts() ([]*dataprovider.DefenderEntry, error)
|
||||
GetHost(ip string) (*dataprovider.DefenderEntry, error)
|
||||
AddEvent(ip string, event HostEvent)
|
||||
IsBanned(ip string) bool
|
||||
GetBanTime(ip string) *time.Time
|
||||
GetScore(ip string) int
|
||||
GetBanTime(ip string) (*time.Time, error)
|
||||
GetScore(ip string) (int, error)
|
||||
DeleteHost(ip string) bool
|
||||
Reload() error
|
||||
}
|
||||
|
@ -78,6 +52,11 @@ type Defender interface {
|
|||
type DefenderConfig struct {
|
||||
// Set to true to enable the defender
|
||||
Enabled bool `json:"enabled" mapstructure:"enabled"`
|
||||
// Defender implementation to use, we support "memory" and "provider".
|
||||
// Using "provider" as driver you can share the defender events among
|
||||
// multiple SFTPGo instances. For a single instance "memory" provider will
|
||||
// be much faster
|
||||
Driver string `json:"driver" mapstructure:"driver"`
|
||||
// BanTime is the number of minutes that a host is banned
|
||||
BanTime int `json:"ban_time" mapstructure:"ban_time"`
|
||||
// Percentage increase of the ban time if a banned host tries to connect again
|
||||
|
@ -97,7 +76,9 @@ type DefenderConfig struct {
|
|||
// the last observation time minutes
|
||||
ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
|
||||
// The number of banned IPs and host scores kept in memory will vary between the
|
||||
// soft and hard limit
|
||||
// soft and hard limit for the "memory" driver. For the "provider" driver the
|
||||
// soft limit is ignored and the hard limit is used to limit the number of entries
|
||||
// to return when you request for the entire host list from the defender
|
||||
EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
|
||||
EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
|
||||
// Path to a file containing a list of ip addresses and/or networks to never ban
|
||||
|
@ -106,19 +87,59 @@ type DefenderConfig struct {
|
|||
BlockListFile string `json:"blocklist_file" mapstructure:"blocklist_file"`
|
||||
}
|
||||
|
||||
type memoryDefender struct {
|
||||
type baseDefender struct {
|
||||
config *DefenderConfig
|
||||
sync.RWMutex
|
||||
// IP addresses of the clients trying to connected are stored inside hosts,
|
||||
// they are added to banned once the thresold is reached.
|
||||
// A violation from a banned host will increase the ban time
|
||||
// based on the configured BanTimeIncrement
|
||||
hosts map[string]hostScore // the key is the host IP
|
||||
banned map[string]time.Time // the key is the host IP
|
||||
safeList *HostList
|
||||
blockList *HostList
|
||||
}
|
||||
|
||||
// Reload reloads block and safe lists
|
||||
func (d *baseDefender) Reload() error {
|
||||
blockList, err := loadHostListFromFile(d.config.BlockListFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Lock()
|
||||
d.blockList = blockList
|
||||
d.Unlock()
|
||||
|
||||
safeList, err := loadHostListFromFile(d.config.SafeListFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Lock()
|
||||
d.safeList = safeList
|
||||
d.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *baseDefender) isBanned(ip string) bool {
|
||||
if d.blockList != nil && d.blockList.isListed(ip) {
|
||||
// permanent ban
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *baseDefender) getScore(event HostEvent) int {
|
||||
var score int
|
||||
|
||||
switch event {
|
||||
case HostEventLoginFailed:
|
||||
score = d.config.ScoreValid
|
||||
case HostEventLimitExceeded:
|
||||
score = d.config.ScoreLimitExceeded
|
||||
case HostEventUserNotFound, HostEventNoLoginTried:
|
||||
score = d.config.ScoreInvalid
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
// HostListFile defines the structure expected for safe/block list files
|
||||
type HostListFile struct {
|
||||
IPAddresses []string `json:"addresses"`
|
||||
|
@ -187,337 +208,6 @@ func (c *DefenderConfig) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
|
||||
err := config.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defender := &memoryDefender{
|
||||
config: config,
|
||||
hosts: make(map[string]hostScore),
|
||||
banned: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if err := defender.Reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return defender, nil
|
||||
}
|
||||
|
||||
// Reload reloads block and safe lists
|
||||
func (d *memoryDefender) Reload() error {
|
||||
blockList, err := loadHostListFromFile(d.config.BlockListFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Lock()
|
||||
d.blockList = blockList
|
||||
d.Unlock()
|
||||
|
||||
safeList, err := loadHostListFromFile(d.config.SafeListFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Lock()
|
||||
d.safeList = safeList
|
||||
d.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHosts returns hosts that are banned or for which some violations have been detected
|
||||
func (d *memoryDefender) GetHosts() []*DefenderEntry {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
var result []*DefenderEntry
|
||||
for k, v := range d.banned {
|
||||
if v.After(time.Now()) {
|
||||
result = append(result, &DefenderEntry{
|
||||
IP: k,
|
||||
BanTime: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
for k, v := range d.hosts {
|
||||
score := 0
|
||||
for _, event := range v.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
if score > 0 {
|
||||
result = append(result, &DefenderEntry{
|
||||
IP: k,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetHost returns a defender host by ip, if any
|
||||
func (d *memoryDefender) GetHost(ip string) (*DefenderEntry, error) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
if banTime.After(time.Now()) {
|
||||
return &DefenderEntry{
|
||||
IP: ip,
|
||||
BanTime: banTime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
score := 0
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
if score > 0 {
|
||||
return &DefenderEntry{
|
||||
IP: ip,
|
||||
Score: score,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
|
||||
// IsBanned returns true if the specified IP is banned
|
||||
// and increase ban time if the IP is found.
|
||||
// This method must be called as soon as the client connects
|
||||
func (d *memoryDefender) IsBanned(ip string) bool {
|
||||
d.RLock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
if banTime.After(time.Now()) {
|
||||
increment := d.config.BanTime * d.config.BanTimeIncrement / 100
|
||||
if increment == 0 {
|
||||
increment++
|
||||
}
|
||||
|
||||
d.RUnlock()
|
||||
|
||||
// we can save an earlier ban time if there are contemporary updates
|
||||
// but this should not make much difference. I prefer to hold a read lock
|
||||
// until possible for performance reasons, this method is called each
|
||||
// time a new client connects and it must be as fast as possible
|
||||
d.Lock()
|
||||
d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
|
||||
d.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
defer d.RUnlock()
|
||||
|
||||
if d.blockList != nil && d.blockList.isListed(ip) {
|
||||
// permanent ban
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// DeleteHost removes the specified IP from the defender lists
|
||||
func (d *memoryDefender) DeleteHost(ip string) bool {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
if _, ok := d.banned[ip]; ok {
|
||||
delete(d.banned, ip)
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := d.hosts[ip]; ok {
|
||||
delete(d.hosts, ip)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddEvent adds an event for the given IP.
|
||||
// This method must be called for clients not yet banned
|
||||
func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
if d.safeList != nil && d.safeList.isListed(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
// ignore events for already banned hosts
|
||||
if v, ok := d.banned[ip]; ok {
|
||||
if v.After(time.Now()) {
|
||||
return
|
||||
}
|
||||
delete(d.banned, ip)
|
||||
}
|
||||
|
||||
var score int
|
||||
|
||||
switch event {
|
||||
case HostEventLoginFailed:
|
||||
score = d.config.ScoreValid
|
||||
case HostEventLimitExceeded:
|
||||
score = d.config.ScoreLimitExceeded
|
||||
case HostEventUserNotFound, HostEventNoLoginTried:
|
||||
score = d.config.ScoreInvalid
|
||||
}
|
||||
|
||||
ev := hostEvent{
|
||||
dateTime: time.Now(),
|
||||
score: score,
|
||||
}
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
hs.Events = append(hs.Events, ev)
|
||||
hs.TotalScore = 0
|
||||
|
||||
idx := 0
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
hs.Events[idx] = event
|
||||
hs.TotalScore += event.score
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
hs.Events = hs.Events[:idx]
|
||||
if hs.TotalScore >= d.config.Threshold {
|
||||
d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
|
||||
delete(d.hosts, ip)
|
||||
d.cleanupBanned()
|
||||
} else {
|
||||
d.hosts[ip] = hs
|
||||
}
|
||||
} else {
|
||||
d.hosts[ip] = hostScore{
|
||||
TotalScore: ev.score,
|
||||
Events: []hostEvent{ev},
|
||||
}
|
||||
d.cleanupHosts()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *memoryDefender) countBanned() int {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
return len(d.banned)
|
||||
}
|
||||
|
||||
func (d *memoryDefender) countHosts() int {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
return len(d.hosts)
|
||||
}
|
||||
|
||||
// GetBanTime returns the ban time for the given IP or nil if the IP is not banned
|
||||
func (d *memoryDefender) GetBanTime(ip string) *time.Time {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
return &banTime
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetScore returns the score for the given IP
|
||||
func (d *memoryDefender) GetScore(ip string) int {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
score := 0
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
func (d *memoryDefender) cleanupBanned() {
|
||||
if len(d.banned) > d.config.EntriesHardLimit {
|
||||
kvList := make(kvList, 0, len(d.banned))
|
||||
|
||||
for k, v := range d.banned {
|
||||
if v.Before(time.Now()) {
|
||||
delete(d.banned, k)
|
||||
}
|
||||
|
||||
kvList = append(kvList, kv{
|
||||
Key: k,
|
||||
Value: v.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
// we removed expired ip addresses, if any, above, this could be enough
|
||||
numToRemove := len(d.banned) - d.config.EntriesSoftLimit
|
||||
|
||||
if numToRemove <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Sort(kvList)
|
||||
|
||||
for idx, kv := range kvList {
|
||||
if idx >= numToRemove {
|
||||
break
|
||||
}
|
||||
|
||||
delete(d.banned, kv.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *memoryDefender) cleanupHosts() {
|
||||
if len(d.hosts) > d.config.EntriesHardLimit {
|
||||
kvList := make(kvList, 0, len(d.hosts))
|
||||
|
||||
for k, v := range d.hosts {
|
||||
value := int64(0)
|
||||
if len(v.Events) > 0 {
|
||||
value = v.Events[len(v.Events)-1].dateTime.UnixNano()
|
||||
}
|
||||
kvList = append(kvList, kv{
|
||||
Key: k,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Sort(kvList)
|
||||
|
||||
numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
|
||||
|
||||
for idx, kv := range kvList {
|
||||
if idx >= numToRemove {
|
||||
break
|
||||
}
|
||||
|
||||
delete(d.hosts, kv.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadHostListFromFile(name string) (*HostList, error) {
|
||||
if name == "" {
|
||||
return nil, nil
|
||||
|
@ -582,14 +272,3 @@ func loadHostListFromFile(name string) (*HostList, error) {
|
|||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type kv struct {
|
||||
Key string
|
||||
Value int64
|
||||
}
|
||||
|
||||
type kvList []kv
|
||||
|
||||
func (p kvList) Len() int { return len(p) }
|
||||
func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value }
|
||||
func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
|
|
@ -73,7 +73,9 @@ func TestBasicDefender(t *testing.T) {
|
|||
assert.False(t, defender.IsBanned("invalid ip"))
|
||||
assert.Equal(t, 0, defender.countBanned())
|
||||
assert.Equal(t, 0, defender.countHosts())
|
||||
assert.Len(t, defender.GetHosts(), 0)
|
||||
hosts, err := defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
_, err = defender.GetHost("10.8.0.4")
|
||||
assert.Error(t, err)
|
||||
|
||||
|
@ -86,35 +88,53 @@ func TestBasicDefender(t *testing.T) {
|
|||
defender.AddEvent(testIP, HostEventLoginFailed)
|
||||
assert.Equal(t, 1, defender.countHosts())
|
||||
assert.Equal(t, 0, defender.countBanned())
|
||||
assert.Equal(t, 1, defender.GetScore(testIP))
|
||||
if assert.Len(t, defender.GetHosts(), 1) {
|
||||
assert.Equal(t, 1, defender.GetHosts()[0].Score)
|
||||
assert.True(t, defender.GetHosts()[0].BanTime.IsZero())
|
||||
assert.Empty(t, defender.GetHosts()[0].GetBanTime())
|
||||
score, err := defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, score)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 1, hosts[0].Score)
|
||||
assert.True(t, hosts[0].BanTime.IsZero())
|
||||
assert.Empty(t, hosts[0].GetBanTime())
|
||||
}
|
||||
host, err := defender.GetHost(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, host.Score)
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Nil(t, defender.GetBanTime(testIP))
|
||||
banTime, err := defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
defender.AddEvent(testIP, HostEventLimitExceeded)
|
||||
assert.Equal(t, 1, defender.countHosts())
|
||||
assert.Equal(t, 0, defender.countBanned())
|
||||
assert.Equal(t, 4, defender.GetScore(testIP))
|
||||
if assert.Len(t, defender.GetHosts(), 1) {
|
||||
assert.Equal(t, 4, defender.GetHosts()[0].Score)
|
||||
score, err = defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, score)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 4, hosts[0].Score)
|
||||
assert.True(t, hosts[0].BanTime.IsZero())
|
||||
assert.Empty(t, hosts[0].GetBanTime())
|
||||
}
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
assert.Equal(t, 0, defender.countHosts())
|
||||
assert.Equal(t, 1, defender.countBanned())
|
||||
assert.Equal(t, 0, defender.GetScore(testIP))
|
||||
assert.NotNil(t, defender.GetBanTime(testIP))
|
||||
if assert.Len(t, defender.GetHosts(), 1) {
|
||||
assert.Equal(t, 0, defender.GetHosts()[0].Score)
|
||||
assert.False(t, defender.GetHosts()[0].BanTime.IsZero())
|
||||
assert.NotEmpty(t, defender.GetHosts()[0].GetBanTime())
|
||||
assert.Equal(t, hex.EncodeToString([]byte(testIP)), defender.GetHosts()[0].GetID())
|
||||
score, err = defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
banTime, err = defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, banTime)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 0, hosts[0].Score)
|
||||
assert.False(t, hosts[0].BanTime.IsZero())
|
||||
assert.NotEmpty(t, hosts[0].GetBanTime())
|
||||
assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
|
||||
}
|
||||
host, err = defender.GetHost(testIP)
|
||||
assert.NoError(t, err)
|
||||
|
@ -134,14 +154,22 @@ func TestBasicDefender(t *testing.T) {
|
|||
assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
|
||||
// testIP1 and testIP2 should be removed
|
||||
assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts())
|
||||
assert.Equal(t, 0, defender.GetScore(testIP1))
|
||||
assert.Equal(t, 0, defender.GetScore(testIP2))
|
||||
assert.Equal(t, 2, defender.GetScore(testIP3))
|
||||
score, err = defender.GetScore(testIP1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
score, err = defender.GetScore(testIP2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
score, err = defender.GetScore(testIP3)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, score)
|
||||
|
||||
defender.AddEvent(testIP3, HostEventNoLoginTried)
|
||||
defender.AddEvent(testIP3, HostEventNoLoginTried)
|
||||
// IP3 is now banned
|
||||
assert.NotNil(t, defender.GetBanTime(testIP3))
|
||||
banTime, err = defender.GetBanTime(testIP3)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, banTime)
|
||||
assert.Equal(t, 0, defender.countHosts())
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
@ -150,9 +178,15 @@ func TestBasicDefender(t *testing.T) {
|
|||
}
|
||||
assert.Equal(t, 0, defender.countHosts())
|
||||
assert.Equal(t, config.EntriesSoftLimit, defender.countBanned())
|
||||
assert.Nil(t, defender.GetBanTime(testIP))
|
||||
assert.Nil(t, defender.GetBanTime(testIP3))
|
||||
assert.NotNil(t, defender.GetBanTime(testIP1))
|
||||
banTime, err = defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
banTime, err = defender.GetBanTime(testIP3)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
banTime, err = defender.GetBanTime(testIP1)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, banTime)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
|
@ -162,11 +196,13 @@ func TestBasicDefender(t *testing.T) {
|
|||
assert.Equal(t, 0, defender.countHosts())
|
||||
assert.Equal(t, defender.config.EntriesSoftLimit, defender.countBanned())
|
||||
|
||||
banTime := defender.GetBanTime(testIP3)
|
||||
banTime, err = defender.GetBanTime(testIP3)
|
||||
assert.NoError(t, err)
|
||||
if assert.NotNil(t, banTime) {
|
||||
assert.True(t, defender.IsBanned(testIP3))
|
||||
// ban time should increase
|
||||
newBanTime := defender.GetBanTime(testIP3)
|
||||
newBanTime, err := defender.GetBanTime(testIP3)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, newBanTime.After(*banTime))
|
||||
}
|
||||
|
||||
|
@ -202,7 +238,8 @@ func TestExpiredHostBans(t *testing.T) {
|
|||
defender.banned[testIP] = time.Now().Add(-24 * time.Hour)
|
||||
|
||||
// the ban is expired testIP should not be listed
|
||||
res := defender.GetHosts()
|
||||
res, err := defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 0)
|
||||
|
||||
assert.False(t, defender.IsBanned(testIP))
|
||||
|
@ -219,7 +256,8 @@ func TestExpiredHostBans(t *testing.T) {
|
|||
assert.Empty(t, entry.GetBanTime())
|
||||
assert.Equal(t, 1, entry.Score)
|
||||
|
||||
res = defender.GetHosts()
|
||||
res, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, res, 1) {
|
||||
assert.Equal(t, testIP, res[0].IP)
|
||||
assert.Empty(t, res[0].GetBanTime())
|
||||
|
@ -244,7 +282,8 @@ func TestExpiredHostBans(t *testing.T) {
|
|||
|
||||
defender.hosts[testIP] = hs
|
||||
// the recorded scored are too old
|
||||
res = defender.GetHosts()
|
||||
res, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 0)
|
||||
_, err = defender.GetHost(testIP)
|
||||
assert.Error(t, err)
|
||||
|
@ -327,13 +366,15 @@ func TestLoadHostListFromFile(t *testing.T) {
|
|||
|
||||
func TestDefenderCleanup(t *testing.T) {
|
||||
d := memoryDefender{
|
||||
baseDefender: baseDefender{
|
||||
config: &DefenderConfig{
|
||||
ObservationTime: 1,
|
||||
EntriesSoftLimit: 2,
|
||||
EntriesHardLimit: 3,
|
||||
},
|
||||
},
|
||||
banned: make(map[string]time.Time),
|
||||
hosts: make(map[string]hostScore),
|
||||
config: &DefenderConfig{
|
||||
ObservationTime: 1,
|
||||
EntriesSoftLimit: 2,
|
||||
EntriesHardLimit: 3,
|
||||
},
|
||||
}
|
||||
|
||||
d.banned["1.1.1.1"] = time.Now().Add(-24 * time.Hour)
|
||||
|
@ -351,7 +392,9 @@ func TestDefenderCleanup(t *testing.T) {
|
|||
|
||||
d.cleanupBanned()
|
||||
assert.Equal(t, d.config.EntriesSoftLimit, d.countBanned())
|
||||
assert.Nil(t, d.GetBanTime("2.2.2.3"))
|
||||
banTime, err := d.GetBanTime("2.2.2.3")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
|
||||
d.hosts["3.3.3.3"] = hostScore{
|
||||
TotalScore: 0,
|
||||
|
@ -398,11 +441,15 @@ func TestDefenderCleanup(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, d.GetScore("3.3.3.3"))
|
||||
score, err := d.GetScore("3.3.3.3")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, score)
|
||||
|
||||
d.cleanupHosts()
|
||||
assert.Equal(t, d.config.EntriesSoftLimit, d.countHosts())
|
||||
assert.Equal(t, 0, d.GetScore("3.3.3.4"))
|
||||
score, err = d.GetScore("3.3.3.4")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
}
|
||||
|
||||
func TestDefenderConfig(t *testing.T) {
|
||||
|
@ -613,7 +660,9 @@ func getDefenderForBench() *memoryDefender {
|
|||
EntriesHardLimit: 100,
|
||||
}
|
||||
return &memoryDefender{
|
||||
config: config,
|
||||
baseDefender: baseDefender{
|
||||
config: config,
|
||||
},
|
||||
hosts: make(map[string]hostScore),
|
||||
banned: make(map[string]time.Time),
|
||||
}
|
||||
|
|
157
common/defenderdb.go
Normal file
157
common/defenderdb.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/logger"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
type dbDefender struct {
|
||||
baseDefender
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
func newDBDefender(config *DefenderConfig) (Defender, error) {
|
||||
err := config.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defender := &dbDefender{
|
||||
baseDefender: baseDefender{
|
||||
config: config,
|
||||
},
|
||||
lastCleanup: time.Time{},
|
||||
}
|
||||
|
||||
if err := defender.Reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return defender, nil
|
||||
}
|
||||
|
||||
// GetHosts returns hosts that are banned or for which some violations have been detected
|
||||
func (d *dbDefender) GetHosts() ([]*dataprovider.DefenderEntry, error) {
|
||||
return dataprovider.GetDefenderHosts(d.getStartObservationTime(), d.config.EntriesHardLimit)
|
||||
}
|
||||
|
||||
// GetHost returns a defender host by ip, if any
|
||||
func (d *dbDefender) GetHost(ip string) (*dataprovider.DefenderEntry, error) {
|
||||
return dataprovider.GetDefenderHostByIP(ip, d.getStartObservationTime())
|
||||
}
|
||||
|
||||
// IsBanned returns true if the specified IP is banned
|
||||
// and increase ban time if the IP is found.
|
||||
// This method must be called as soon as the client connects
|
||||
func (d *dbDefender) IsBanned(ip string) bool {
|
||||
d.RLock()
|
||||
if d.baseDefender.isBanned(ip) {
|
||||
d.RUnlock()
|
||||
return true
|
||||
}
|
||||
d.RUnlock()
|
||||
|
||||
_, err := dataprovider.IsDefenderHostBanned(ip)
|
||||
if err != nil {
|
||||
// not found or another error, we allow this host
|
||||
return false
|
||||
}
|
||||
increment := d.config.BanTime * d.config.BanTimeIncrement / 100
|
||||
if increment == 0 {
|
||||
increment++
|
||||
}
|
||||
dataprovider.UpdateDefenderBanTime(ip, increment) //nolint:errcheck
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteHost removes the specified IP from the defender lists
|
||||
func (d *dbDefender) DeleteHost(ip string) bool {
|
||||
if _, err := d.GetHost(ip); err != nil {
|
||||
return false
|
||||
}
|
||||
return dataprovider.DeleteDefenderHost(ip) == nil
|
||||
}
|
||||
|
||||
// AddEvent adds an event for the given IP.
|
||||
// This method must be called for clients not yet banned
|
||||
func (d *dbDefender) AddEvent(ip string, event HostEvent) {
|
||||
d.RLock()
|
||||
if d.safeList != nil && d.safeList.isListed(ip) {
|
||||
d.RUnlock()
|
||||
return
|
||||
}
|
||||
d.RUnlock()
|
||||
|
||||
score := d.baseDefender.getScore(event)
|
||||
|
||||
host, err := dataprovider.AddDefenderEvent(ip, score, d.getStartObservationTime())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if host.Score > d.config.Threshold {
|
||||
banTime := time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
|
||||
err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(banTime))
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
d.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// GetBanTime returns the ban time for the given IP or nil if the IP is not banned
|
||||
func (d *dbDefender) GetBanTime(ip string) (*time.Time, error) {
|
||||
host, err := d.GetHost(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if host.BanTime.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
return &host.BanTime, nil
|
||||
}
|
||||
|
||||
// GetScore returns the score for the given IP
|
||||
func (d *dbDefender) GetScore(ip string) (int, error) {
|
||||
host, err := d.GetHost(ip)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return host.Score, nil
|
||||
}
|
||||
|
||||
func (d *dbDefender) cleanup() {
|
||||
lastCleanup := d.getLastCleanup()
|
||||
if lastCleanup.IsZero() || lastCleanup.Add(time.Duration(d.config.ObservationTime)*time.Minute*3).Before(time.Now()) {
|
||||
// FIXME: this could be racy in rare cases but it is better than acquire the lock for the cleanup duration
|
||||
// or to always acquire a read/write lock.
|
||||
// Concurrent cleanups could happen anyway from multiple SFTPGo instances and should not cause any issues
|
||||
d.setLastCleanup(time.Now())
|
||||
expireTime := time.Now().Add(-time.Duration(d.config.ObservationTime+1) * time.Minute)
|
||||
logger.Debug(logSender, "", "cleanup defender hosts before %v, last cleanup %v", expireTime, lastCleanup)
|
||||
if err := dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(expireTime)); err != nil {
|
||||
logger.Error(logSender, "", "defender cleanup error, reset last cleanup to %v", lastCleanup)
|
||||
d.setLastCleanup(lastCleanup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dbDefender) getStartObservationTime() int64 {
|
||||
t := time.Now().Add(-time.Duration(d.config.ObservationTime) * time.Minute)
|
||||
return util.GetTimeAsMsSinceEpoch(t)
|
||||
}
|
||||
|
||||
func (d *dbDefender) getLastCleanup() time.Time {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
return d.lastCleanup
|
||||
}
|
||||
|
||||
func (d *dbDefender) setLastCleanup(when time.Time) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
d.lastCleanup = when
|
||||
}
|
280
common/defenderdb_test.go
Normal file
280
common/defenderdb_test.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
func TestBasicDbDefender(t *testing.T) {
|
||||
if !isDbDefenderSupported() {
|
||||
t.Skip("this test is not supported with the current database provider")
|
||||
}
|
||||
config := &DefenderConfig{
|
||||
Enabled: true,
|
||||
BanTime: 10,
|
||||
BanTimeIncrement: 2,
|
||||
Threshold: 5,
|
||||
ScoreInvalid: 2,
|
||||
ScoreValid: 1,
|
||||
ScoreLimitExceeded: 3,
|
||||
ObservationTime: 15,
|
||||
EntriesSoftLimit: 1,
|
||||
EntriesHardLimit: 10,
|
||||
SafeListFile: "slFile",
|
||||
BlockListFile: "blFile",
|
||||
}
|
||||
_, err := newDBDefender(config)
|
||||
assert.Error(t, err)
|
||||
|
||||
bl := HostListFile{
|
||||
IPAddresses: []string{"172.16.1.1", "172.16.1.2"},
|
||||
CIDRNetworks: []string{"10.8.0.0/24"},
|
||||
}
|
||||
sl := HostListFile{
|
||||
IPAddresses: []string{"172.16.1.3", "172.16.1.4"},
|
||||
CIDRNetworks: []string{"192.168.8.0/24"},
|
||||
}
|
||||
blFile := filepath.Join(os.TempDir(), "bl.json")
|
||||
slFile := filepath.Join(os.TempDir(), "sl.json")
|
||||
|
||||
data, err := json.Marshal(bl)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(blFile, data, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data, err = json.Marshal(sl)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(slFile, data, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.BlockListFile = blFile
|
||||
_, err = newDBDefender(config)
|
||||
assert.Error(t, err)
|
||||
config.SafeListFile = slFile
|
||||
d, err := newDBDefender(config)
|
||||
assert.NoError(t, err)
|
||||
defender := d.(*dbDefender)
|
||||
assert.True(t, defender.IsBanned("172.16.1.1"))
|
||||
assert.False(t, defender.IsBanned("172.16.1.10"))
|
||||
assert.False(t, defender.IsBanned("10.8.1.3"))
|
||||
assert.True(t, defender.IsBanned("10.8.0.4"))
|
||||
assert.False(t, defender.IsBanned("invalid ip"))
|
||||
hosts, err := defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
_, err = defender.GetHost("10.8.0.3")
|
||||
assert.Error(t, err)
|
||||
|
||||
defender.AddEvent("172.16.1.4", HostEventLoginFailed)
|
||||
defender.AddEvent("192.168.8.4", HostEventUserNotFound)
|
||||
defender.AddEvent("172.16.1.3", HostEventLimitExceeded)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
assert.True(t, defender.getLastCleanup().IsZero())
|
||||
|
||||
testIP := "123.45.67.89"
|
||||
defender.AddEvent(testIP, HostEventLoginFailed)
|
||||
lastCleanup := defender.getLastCleanup()
|
||||
assert.False(t, lastCleanup.IsZero())
|
||||
score, err := defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, score)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 1, hosts[0].Score)
|
||||
assert.True(t, hosts[0].BanTime.IsZero())
|
||||
assert.Empty(t, hosts[0].GetBanTime())
|
||||
}
|
||||
host, err := defender.GetHost(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, host.Score)
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
banTime, err := defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
defender.AddEvent(testIP, HostEventLimitExceeded)
|
||||
score, err = defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 4, score)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 4, hosts[0].Score)
|
||||
assert.True(t, hosts[0].BanTime.IsZero())
|
||||
assert.Empty(t, hosts[0].GetBanTime())
|
||||
}
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
score, err = defender.GetScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, score)
|
||||
banTime, err = defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, banTime)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, 0, hosts[0].Score)
|
||||
assert.False(t, hosts[0].BanTime.IsZero())
|
||||
assert.NotEmpty(t, hosts[0].GetBanTime())
|
||||
assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID())
|
||||
}
|
||||
host, err = defender.GetHost(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
// ban time should increase
|
||||
assert.True(t, defender.IsBanned(testIP))
|
||||
newBanTime, err := defender.GetBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, newBanTime.After(*banTime))
|
||||
|
||||
assert.True(t, defender.DeleteHost(testIP))
|
||||
assert.False(t, defender.DeleteHost(testIP))
|
||||
// test cleanup
|
||||
testIP1 := "123.45.67.90"
|
||||
testIP2 := "123.45.67.91"
|
||||
testIP3 := "123.45.67.92"
|
||||
for i := 0; i < 3; i++ {
|
||||
defender.AddEvent(testIP, HostEventNoLoginTried)
|
||||
defender.AddEvent(testIP1, HostEventNoLoginTried)
|
||||
defender.AddEvent(testIP2, HostEventNoLoginTried)
|
||||
}
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 3)
|
||||
for _, host := range hosts {
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.False(t, host.BanTime.IsZero())
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
}
|
||||
defender.AddEvent(testIP3, HostEventLoginFailed)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 4)
|
||||
// now set a ban time in the past, so the host will be cleanead up
|
||||
for _, ip := range []string{testIP1, testIP2} {
|
||||
err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 4)
|
||||
for _, host := range hosts {
|
||||
switch host.IP {
|
||||
case testIP:
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.False(t, host.BanTime.IsZero())
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
case testIP3:
|
||||
assert.Equal(t, 1, host.Score)
|
||||
assert.True(t, host.BanTime.IsZero())
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
default:
|
||||
assert.Equal(t, 6, host.Score)
|
||||
assert.True(t, host.BanTime.IsZero())
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
}
|
||||
}
|
||||
host, err = defender.GetHost(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.False(t, host.BanTime.IsZero())
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
host, err = defender.GetHost(testIP3)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, host.Score)
|
||||
assert.True(t, host.BanTime.IsZero())
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
// cleanup db
|
||||
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
|
||||
assert.NoError(t, err)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
assert.Equal(t, testIP, hosts[0].IP)
|
||||
assert.Equal(t, 0, hosts[0].Score)
|
||||
assert.False(t, hosts[0].BanTime.IsZero())
|
||||
assert.NotEmpty(t, hosts[0].GetBanTime())
|
||||
}
|
||||
err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute)))
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute)))
|
||||
assert.NoError(t, err)
|
||||
hosts, err = defender.GetHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
|
||||
err = os.Remove(slFile)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(blFile)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDbDefenderCleanup(t *testing.T) {
|
||||
if !isDbDefenderSupported() {
|
||||
t.Skip("this test is not supported with the current database provider")
|
||||
}
|
||||
config := &DefenderConfig{
|
||||
Enabled: true,
|
||||
BanTime: 10,
|
||||
BanTimeIncrement: 2,
|
||||
Threshold: 5,
|
||||
ScoreInvalid: 2,
|
||||
ScoreValid: 1,
|
||||
ScoreLimitExceeded: 3,
|
||||
ObservationTime: 15,
|
||||
EntriesSoftLimit: 1,
|
||||
EntriesHardLimit: 10,
|
||||
}
|
||||
d, err := newDBDefender(config)
|
||||
assert.NoError(t, err)
|
||||
defender := d.(*dbDefender)
|
||||
lastCleanup := defender.getLastCleanup()
|
||||
assert.True(t, lastCleanup.IsZero())
|
||||
defender.cleanup()
|
||||
lastCleanup = defender.getLastCleanup()
|
||||
assert.False(t, lastCleanup.IsZero())
|
||||
defender.cleanup()
|
||||
assert.Equal(t, lastCleanup, defender.getLastCleanup())
|
||||
defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4))
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
defender.cleanup()
|
||||
assert.True(t, lastCleanup.Before(defender.getLastCleanup()))
|
||||
|
||||
providerConf := dataprovider.GetProviderConfig()
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
lastCleanup = time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)
|
||||
defender.setLastCleanup(lastCleanup)
|
||||
defender.cleanup()
|
||||
// cleanup will fail and so last cleanup should be reset to the previous value
|
||||
assert.Equal(t, lastCleanup, defender.getLastCleanup())
|
||||
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func isDbDefenderSupported() bool {
|
||||
// SQLite shares the implementation with other SQL-based provider but it makes no sense
|
||||
// to use it outside test cases
|
||||
switch dataprovider.GetProviderStatus().Driver {
|
||||
case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
|
||||
dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
326
common/defendermem.go
Normal file
326
common/defendermem.go
Normal file
|
@ -0,0 +1,326 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
type memoryDefender struct {
|
||||
baseDefender
|
||||
// IP addresses of the clients trying to connected are stored inside hosts,
|
||||
// they are added to banned once the thresold is reached.
|
||||
// A violation from a banned host will increase the ban time
|
||||
// based on the configured BanTimeIncrement
|
||||
hosts map[string]hostScore // the key is the host IP
|
||||
banned map[string]time.Time // the key is the host IP
|
||||
}
|
||||
|
||||
func newInMemoryDefender(config *DefenderConfig) (Defender, error) {
|
||||
err := config.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defender := &memoryDefender{
|
||||
baseDefender: baseDefender{
|
||||
config: config,
|
||||
},
|
||||
hosts: make(map[string]hostScore),
|
||||
banned: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
if err := defender.Reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return defender, nil
|
||||
}
|
||||
|
||||
// GetHosts returns hosts that are banned or for which some violations have been detected
|
||||
func (d *memoryDefender) GetHosts() ([]*dataprovider.DefenderEntry, error) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
var result []*dataprovider.DefenderEntry
|
||||
for k, v := range d.banned {
|
||||
if v.After(time.Now()) {
|
||||
result = append(result, &dataprovider.DefenderEntry{
|
||||
IP: k,
|
||||
BanTime: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
for k, v := range d.hosts {
|
||||
score := 0
|
||||
for _, event := range v.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
if score > 0 {
|
||||
result = append(result, &dataprovider.DefenderEntry{
|
||||
IP: k,
|
||||
Score: score,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetHost returns a defender host by ip, if any
|
||||
func (d *memoryDefender) GetHost(ip string) (*dataprovider.DefenderEntry, error) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
if banTime.After(time.Now()) {
|
||||
return &dataprovider.DefenderEntry{
|
||||
IP: ip,
|
||||
BanTime: banTime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
score := 0
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
if score > 0 {
|
||||
return &dataprovider.DefenderEntry{
|
||||
IP: ip,
|
||||
Score: score,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
|
||||
// IsBanned returns true if the specified IP is banned
|
||||
// and increase ban time if the IP is found.
|
||||
// This method must be called as soon as the client connects
|
||||
func (d *memoryDefender) IsBanned(ip string) bool {
|
||||
d.RLock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
if banTime.After(time.Now()) {
|
||||
increment := d.config.BanTime * d.config.BanTimeIncrement / 100
|
||||
if increment == 0 {
|
||||
increment++
|
||||
}
|
||||
|
||||
d.RUnlock()
|
||||
|
||||
// we can save an earlier ban time if there are contemporary updates
|
||||
// but this should not make much difference. I prefer to hold a read lock
|
||||
// until possible for performance reasons, this method is called each
|
||||
// time a new client connects and it must be as fast as possible
|
||||
d.Lock()
|
||||
d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute)
|
||||
d.Unlock()
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
defer d.RUnlock()
|
||||
|
||||
return d.baseDefender.isBanned(ip)
|
||||
}
|
||||
|
||||
// DeleteHost removes the specified IP from the defender lists
|
||||
func (d *memoryDefender) DeleteHost(ip string) bool {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
if _, ok := d.banned[ip]; ok {
|
||||
delete(d.banned, ip)
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := d.hosts[ip]; ok {
|
||||
delete(d.hosts, ip)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddEvent adds an event for the given IP.
|
||||
// This method must be called for clients not yet banned
|
||||
func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
if d.safeList != nil && d.safeList.isListed(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
// ignore events for already banned hosts
|
||||
if v, ok := d.banned[ip]; ok {
|
||||
if v.After(time.Now()) {
|
||||
return
|
||||
}
|
||||
delete(d.banned, ip)
|
||||
}
|
||||
|
||||
score := d.baseDefender.getScore(event)
|
||||
|
||||
ev := hostEvent{
|
||||
dateTime: time.Now(),
|
||||
score: score,
|
||||
}
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
hs.Events = append(hs.Events, ev)
|
||||
hs.TotalScore = 0
|
||||
|
||||
idx := 0
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
hs.Events[idx] = event
|
||||
hs.TotalScore += event.score
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
hs.Events = hs.Events[:idx]
|
||||
if hs.TotalScore >= d.config.Threshold {
|
||||
d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
|
||||
delete(d.hosts, ip)
|
||||
d.cleanupBanned()
|
||||
} else {
|
||||
d.hosts[ip] = hs
|
||||
}
|
||||
} else {
|
||||
d.hosts[ip] = hostScore{
|
||||
TotalScore: ev.score,
|
||||
Events: []hostEvent{ev},
|
||||
}
|
||||
d.cleanupHosts()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *memoryDefender) countBanned() int {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
return len(d.banned)
|
||||
}
|
||||
|
||||
func (d *memoryDefender) countHosts() int {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
return len(d.hosts)
|
||||
}
|
||||
|
||||
// GetBanTime returns the ban time for the given IP or nil if the IP is not banned
|
||||
func (d *memoryDefender) GetBanTime(ip string) (*time.Time, error) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
if banTime, ok := d.banned[ip]; ok {
|
||||
return &banTime, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetScore returns the score for the given IP
|
||||
func (d *memoryDefender) GetScore(ip string) (int, error) {
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
|
||||
score := 0
|
||||
|
||||
if hs, ok := d.hosts[ip]; ok {
|
||||
for _, event := range hs.Events {
|
||||
if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) {
|
||||
score += event.score
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return score, nil
|
||||
}
|
||||
|
||||
func (d *memoryDefender) cleanupBanned() {
|
||||
if len(d.banned) > d.config.EntriesHardLimit {
|
||||
kvList := make(kvList, 0, len(d.banned))
|
||||
|
||||
for k, v := range d.banned {
|
||||
if v.Before(time.Now()) {
|
||||
delete(d.banned, k)
|
||||
}
|
||||
|
||||
kvList = append(kvList, kv{
|
||||
Key: k,
|
||||
Value: v.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
// we removed expired ip addresses, if any, above, this could be enough
|
||||
numToRemove := len(d.banned) - d.config.EntriesSoftLimit
|
||||
|
||||
if numToRemove <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Sort(kvList)
|
||||
|
||||
for idx, kv := range kvList {
|
||||
if idx >= numToRemove {
|
||||
break
|
||||
}
|
||||
|
||||
delete(d.banned, kv.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *memoryDefender) cleanupHosts() {
|
||||
if len(d.hosts) > d.config.EntriesHardLimit {
|
||||
kvList := make(kvList, 0, len(d.hosts))
|
||||
|
||||
for k, v := range d.hosts {
|
||||
value := int64(0)
|
||||
if len(v.Events) > 0 {
|
||||
value = v.Events[len(v.Events)-1].dateTime.UnixNano()
|
||||
}
|
||||
kvList = append(kvList, kv{
|
||||
Key: k,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Sort(kvList)
|
||||
|
||||
numToRemove := len(d.hosts) - d.config.EntriesSoftLimit
|
||||
|
||||
for idx, kv := range kvList {
|
||||
if idx >= numToRemove {
|
||||
break
|
||||
}
|
||||
|
||||
delete(d.hosts, kv.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type kv struct {
|
||||
Key string
|
||||
Value int64
|
||||
}
|
||||
|
||||
type kvList []kv
|
||||
|
||||
func (p kvList) Len() int { return len(p) }
|
||||
func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value }
|
||||
func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
|
@ -38,6 +38,7 @@ import (
|
|||
"github.com/drakkan/sftpgo/v2/logger"
|
||||
"github.com/drakkan/sftpgo/v2/mfa"
|
||||
"github.com/drakkan/sftpgo/v2/sdk"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
"github.com/drakkan/sftpgo/v2/vfs"
|
||||
)
|
||||
|
||||
|
@ -2161,6 +2162,58 @@ func TestUserPasswordHashing(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDbDefenderErrors(t *testing.T) {
|
||||
if !isDbDefenderSupported() {
|
||||
t.Skip("this test is not supported with the current database provider")
|
||||
}
|
||||
configCopy := common.Config
|
||||
common.Config.DefenderConfig.Enabled = true
|
||||
common.Config.DefenderConfig.Driver = common.DefenderDriverProvider
|
||||
err := common.Initialize(common.Config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testIP := "127.1.1.1"
|
||||
hosts, err := common.GetDefenderHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
common.AddDefenderEvent(testIP, common.HostEventLimitExceeded)
|
||||
hosts, err = common.GetDefenderHosts()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 1)
|
||||
score, err := common.GetDefenderScore(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, score)
|
||||
banTime, err := common.GetDefenderBanTime(testIP)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, banTime)
|
||||
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.AddDefenderEvent(testIP, common.HostEventLimitExceeded)
|
||||
_, err = common.GetDefenderHosts()
|
||||
assert.Error(t, err)
|
||||
_, err = common.GetDefenderHost(testIP)
|
||||
assert.Error(t, err)
|
||||
_, err = common.GetDefenderBanTime(testIP)
|
||||
assert.Error(t, err)
|
||||
_, err = common.GetDefenderScore(testIP)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour)))
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.Config = configCopy
|
||||
err = common.Initialize(common.Config)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDelayedQuotaUpdater(t *testing.T) {
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
@ -3302,3 +3355,15 @@ func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) {
|
|||
Algorithm: algo,
|
||||
})
|
||||
}
|
||||
|
||||
func isDbDefenderSupported() bool {
|
||||
// SQLite shares the implementation with other SQL-based provider but it makes no sense
|
||||
// to use it outside test cases
|
||||
switch dataprovider.GetProviderStatus().Driver {
|
||||
case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
|
||||
dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -144,6 +144,7 @@ func Init() {
|
|||
MaxPerHostConnections: 20,
|
||||
DefenderConfig: common.DefenderConfig{
|
||||
Enabled: false,
|
||||
Driver: common.DefenderDriverMemory,
|
||||
BanTime: 30,
|
||||
BanTimeIncrement: 50,
|
||||
Threshold: 15,
|
||||
|
@ -530,6 +531,12 @@ func LoadConfig(configDir, configFile string) error {
|
|||
}
|
||||
// viper only supports slice of strings from env vars, so we use our custom method
|
||||
loadBindingsFromEnv()
|
||||
resetInvalidConfigs()
|
||||
logger.Debug(logSender, "", "config file used: '%#v', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf())
|
||||
return nil
|
||||
}
|
||||
|
||||
func resetInvalidConfigs() {
|
||||
if strings.TrimSpace(globalConf.SFTPD.Banner) == "" {
|
||||
globalConf.SFTPD.Banner = defaultSFTPDBanner
|
||||
}
|
||||
|
@ -537,39 +544,48 @@ func LoadConfig(configDir, configFile string) error {
|
|||
globalConf.FTPD.Banner = defaultFTPDBanner
|
||||
}
|
||||
if globalConf.ProviderConf.UsersBaseDir != "" && !util.IsFileInputValid(globalConf.ProviderConf.UsersBaseDir) {
|
||||
err = fmt.Errorf("invalid users base dir %#v will be ignored", globalConf.ProviderConf.UsersBaseDir)
|
||||
warn := fmt.Sprintf("invalid users base dir %#v will be ignored", globalConf.ProviderConf.UsersBaseDir)
|
||||
globalConf.ProviderConf.UsersBaseDir = ""
|
||||
logger.Warn(logSender, "", "Configuration error: %v", err)
|
||||
logger.WarnToConsole("Configuration error: %v", err)
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
if globalConf.Common.UploadMode < 0 || globalConf.Common.UploadMode > 2 {
|
||||
warn := fmt.Sprintf("invalid upload_mode 0, 1 and 2 are supported, configured: %v reset upload_mode to 0",
|
||||
globalConf.Common.UploadMode)
|
||||
globalConf.Common.UploadMode = 0
|
||||
logger.Warn(logSender, "", "Configuration error: %v", warn)
|
||||
logger.WarnToConsole("Configuration error: %v", warn)
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
if globalConf.Common.ProxyProtocol < 0 || globalConf.Common.ProxyProtocol > 2 {
|
||||
warn := fmt.Sprintf("invalid proxy_protocol 0, 1 and 2 are supported, configured: %v reset proxy_protocol to 0",
|
||||
globalConf.Common.ProxyProtocol)
|
||||
globalConf.Common.ProxyProtocol = 0
|
||||
logger.Warn(logSender, "", "Configuration error: %v", warn)
|
||||
logger.WarnToConsole("Configuration error: %v", warn)
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
if globalConf.ProviderConf.ExternalAuthScope < 0 || globalConf.ProviderConf.ExternalAuthScope > 15 {
|
||||
warn := fmt.Sprintf("invalid external_auth_scope: %v reset to 0", globalConf.ProviderConf.ExternalAuthScope)
|
||||
globalConf.ProviderConf.ExternalAuthScope = 0
|
||||
logger.Warn(logSender, "", "Configuration error: %v", warn)
|
||||
logger.WarnToConsole("Configuration error: %v", warn)
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
if globalConf.ProviderConf.CredentialsPath == "" {
|
||||
warn := "invalid credentials path, reset to \"credentials\""
|
||||
globalConf.ProviderConf.CredentialsPath = "credentials"
|
||||
logger.Warn(logSender, "", "Configuration error: %v", warn)
|
||||
logger.WarnToConsole("Configuration error: %v", warn)
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider {
|
||||
if !globalConf.ProviderConf.IsDefenderSupported() {
|
||||
warn := fmt.Sprintf("provider based defender is not supported with data provider %#v, "+
|
||||
"the memory defender implementation will be used. If you want to use the provider defender "+
|
||||
"implementation please switch to a shared/distributed data provider",
|
||||
globalConf.ProviderConf.Driver)
|
||||
globalConf.Common.DefenderConfig.Driver = common.DefenderDriverMemory
|
||||
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
|
||||
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
|
||||
}
|
||||
}
|
||||
logger.Debug(logSender, "", "config file used: '%#v', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf())
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadBindingsFromEnv() {
|
||||
|
@ -1200,6 +1216,7 @@ func setViperDefaults() {
|
|||
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
|
||||
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
|
||||
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
|
||||
viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver)
|
||||
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
|
||||
viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement)
|
||||
viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold)
|
||||
|
|
|
@ -3,6 +3,4 @@
|
|||
|
||||
package config
|
||||
|
||||
func setViperAdditionalConfigPaths() {
|
||||
|
||||
}
|
||||
func setViperAdditionalConfigPaths() {}
|
||||
|
|
|
@ -250,6 +250,35 @@ func TestInvalidUsersBaseDir(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDefenderProviderDriver(t *testing.T) {
|
||||
if config.GetProviderConf().Driver != dataprovider.SQLiteDataProviderName {
|
||||
t.Skip("this test is not supported with the current database provider")
|
||||
}
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.Driver = dataprovider.BoltDataProviderName
|
||||
commonConfig := config.GetCommonConfig()
|
||||
commonConfig.DefenderConfig.Enabled = true
|
||||
commonConfig.DefenderConfig.Driver = common.DefenderDriverProvider
|
||||
c := make(map[string]interface{})
|
||||
c["common"] = commonConfig
|
||||
c["data_provider"] = providerConf
|
||||
jsonConf, err := json.Marshal(c)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, confName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataprovider.BoltDataProviderName, config.GetProviderConf().Driver)
|
||||
assert.Equal(t, common.DefenderDriverMemory, config.GetCommonConfig().DefenderConfig.Driver)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSetGetConfig(t *testing.T) {
|
||||
reset()
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
boltDatabaseVersion = 14
|
||||
boltDatabaseVersion = 15
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -1365,6 +1365,38 @@ func (p *BoltProvider) updateShareLastUse(shareID string, numTokens int) error {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) updateDefenderBanTime(ip string, minutes int) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) deleteDefenderHost(ip string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) addDefenderEvent(ip string, score int) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) setDefenderBanTime(ip string, banTime int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) cleanupDefender(from int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *BoltProvider) close() error {
|
||||
return p.dbHandle.Close()
|
||||
}
|
||||
|
@ -1393,13 +1425,15 @@ func (p *BoltProvider) migrateDatabase() error {
|
|||
logger.ErrorToConsole("%v", err)
|
||||
return err
|
||||
case version == 10:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 14)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 15)
|
||||
case version == 11:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 14)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 15)
|
||||
case version == 12:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 14)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 15)
|
||||
case version == 13:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 14)
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 15)
|
||||
case version == 14:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 15)
|
||||
default:
|
||||
if version > boltDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
|
||||
|
@ -1421,13 +1455,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error {
|
|||
return errors.New("current version match target version, nothing to do")
|
||||
}
|
||||
switch dbVersion.Version {
|
||||
case 14:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 10)
|
||||
case 13:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 10)
|
||||
case 12:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 10)
|
||||
case 11:
|
||||
case 15, 14, 13, 12, 11:
|
||||
return updateBoltDatabaseVersion(p.dbHandle, 10)
|
||||
default:
|
||||
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -125,6 +126,8 @@ var (
|
|||
SSHMultiStepsLoginMethods = []string{SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
|
||||
// ErrNoAuthTryed defines the error for connection closed before authentication
|
||||
ErrNoAuthTryed = errors.New("no auth tryed")
|
||||
// ErrNotImplemented defines the error for features not supported for a particular data provider
|
||||
ErrNotImplemented = errors.New("feature not supported with the configured data provider")
|
||||
// ValidProtocols defines all the valid protcols
|
||||
ValidProtocols = []string{protocolSSH, protocolFTP, protocolWebDAV, protocolHTTP}
|
||||
// MFAProtocols defines the supported protocols for multi-factor authentication
|
||||
|
@ -160,6 +163,8 @@ var (
|
|||
sqlTableAdmins = "admins"
|
||||
sqlTableAPIKeys = "api_keys"
|
||||
sqlTableShares = "shares"
|
||||
sqlTableDefenderHosts = "defender_hosts"
|
||||
sqlTableDefenderEvents = "defender_events"
|
||||
sqlTableSchemaVersion = "schema_version"
|
||||
argon2Params *argon2id.Params
|
||||
lastLoginMinDelay = 10 * time.Minute
|
||||
|
@ -366,6 +371,52 @@ type Config struct {
|
|||
IsShared int `json:"is_shared" mapstructure:"is_shared"`
|
||||
}
|
||||
|
||||
// IsDefenderSupported returns true if the configured provider supports the defender
|
||||
func (c *Config) IsDefenderSupported() bool {
|
||||
switch c.Driver {
|
||||
case MySQLDataProviderName, PGSQLDataProviderName, CockroachDataProviderName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DefenderEntry defines a defender entry
|
||||
type DefenderEntry struct {
|
||||
ID int64 `json:"-"`
|
||||
IP string `json:"ip"`
|
||||
Score int `json:"score,omitempty"`
|
||||
BanTime time.Time `json:"ban_time,omitempty"`
|
||||
}
|
||||
|
||||
// GetID returns an unique ID for a defender entry
|
||||
func (d *DefenderEntry) GetID() string {
|
||||
return hex.EncodeToString([]byte(d.IP))
|
||||
}
|
||||
|
||||
// GetBanTime returns the ban time for a defender entry as string
|
||||
func (d *DefenderEntry) GetBanTime() string {
|
||||
if d.BanTime.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return d.BanTime.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON encoding of a DefenderEntry.
|
||||
func (d *DefenderEntry) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
ID string `json:"id"`
|
||||
IP string `json:"ip"`
|
||||
Score int `json:"score,omitempty"`
|
||||
BanTime string `json:"ban_time,omitempty"`
|
||||
}{
|
||||
ID: d.GetID(),
|
||||
IP: d.IP,
|
||||
Score: d.Score,
|
||||
BanTime: d.GetBanTime(),
|
||||
})
|
||||
}
|
||||
|
||||
// BackupData defines the structure for the backup/restore files
|
||||
type BackupData struct {
|
||||
Users []User `json:"users"`
|
||||
|
@ -452,6 +503,14 @@ type Provider interface {
|
|||
getShares(limit int, offset int, order, username string) ([]Share, error)
|
||||
dumpShares() ([]Share, error)
|
||||
updateShareLastUse(shareID string, numTokens int) error
|
||||
getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error)
|
||||
getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error)
|
||||
isDefenderHostBanned(ip string) (*DefenderEntry, error)
|
||||
updateDefenderBanTime(ip string, minutes int) error
|
||||
deleteDefenderHost(ip string) error
|
||||
addDefenderEvent(ip string, score int) error
|
||||
setDefenderBanTime(ip string, banTime int64) error
|
||||
cleanupDefender(from int64) error
|
||||
checkAvailability() error
|
||||
close() error
|
||||
reloadConfig() error
|
||||
|
@ -861,6 +920,50 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard
|
|||
return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol)
|
||||
}
|
||||
|
||||
// GetDefenderHosts returns hosts that are banned or for which some violations have been detected
|
||||
func GetDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return provider.getDefenderHosts(from, limit)
|
||||
}
|
||||
|
||||
// GetDefenderHostByIP returns a defender host by ip, if any
|
||||
func GetDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return provider.getDefenderHostByIP(ip, from)
|
||||
}
|
||||
|
||||
// IsDefenderHostBanned returns a defender entry and no error if the specified host is banned
|
||||
func IsDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return provider.isDefenderHostBanned(ip)
|
||||
}
|
||||
|
||||
// UpdateDefenderBanTime increments ban time for the specified ip
|
||||
func UpdateDefenderBanTime(ip string, minutes int) error {
|
||||
return provider.updateDefenderBanTime(ip, minutes)
|
||||
}
|
||||
|
||||
// DeleteDefenderHost removes the specified IP from the defender lists
|
||||
func DeleteDefenderHost(ip string) error {
|
||||
return provider.deleteDefenderHost(ip)
|
||||
}
|
||||
|
||||
// AddDefenderEvent adds an event for the given IP with the given score
|
||||
// and returns the host with the updated score
|
||||
func AddDefenderEvent(ip string, score int, from int64) (*DefenderEntry, error) {
|
||||
if err := provider.addDefenderEvent(ip, score); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider.getDefenderHostByIP(ip, from)
|
||||
}
|
||||
|
||||
// SetDefenderBanTime sets the ban time for the specified IP
|
||||
func SetDefenderBanTime(ip string, banTime int64) error {
|
||||
return provider.setDefenderBanTime(ip, banTime)
|
||||
}
|
||||
|
||||
// CleanupDefender removes events and hosts older than "from" from the data provider
|
||||
func CleanupDefender(from int64) error {
|
||||
return provider.cleanupDefender(from)
|
||||
}
|
||||
|
||||
// UpdateShareLastUse updates the LastUseAt and UsedTokens for the given share
|
||||
func UpdateShareLastUse(share *Share, numTokens int) error {
|
||||
return provider.updateShareLastUse(share.ShareID, numTokens)
|
||||
|
@ -1253,6 +1356,11 @@ func ParseDumpData(data []byte) (BackupData, error) {
|
|||
return dump, err
|
||||
}
|
||||
|
||||
// GetProviderConfig returns the current provider configuration
|
||||
func GetProviderConfig() Config {
|
||||
return config
|
||||
}
|
||||
|
||||
// GetProviderStatus returns an error if the provider is not available
|
||||
func GetProviderStatus() ProviderStatus {
|
||||
err := provider.checkAvailability()
|
||||
|
|
|
@ -1256,6 +1256,38 @@ func (p *MemoryProvider) updateShareLastUse(shareID string, numTokens int) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) updateDefenderBanTime(ip string, minutes int) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) deleteDefenderHost(ip string) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) addDefenderEvent(ip string, score int) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) setDefenderBanTime(ip string, banTime int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) cleanupDefender(from int64) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) getNextID() int64 {
|
||||
nextID := int64(1)
|
||||
for _, v := range p.dbHandle.users {
|
||||
|
|
|
@ -27,6 +27,8 @@ const (
|
|||
"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{shares}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{users}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" +
|
||||
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
|
||||
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
|
||||
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
|
||||
|
@ -82,6 +84,17 @@ const (
|
|||
"ALTER TABLE `{{shares}}` ADD CONSTRAINT `{{prefix}}shares_user_id_fk_users_id` " +
|
||||
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;"
|
||||
mysqlV14DownSQL = "DROP TABLE `{{shares}}` CASCADE;"
|
||||
mysqlV15SQL = "CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
|
||||
"`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
|
||||
"CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
|
||||
"`date_time` bigint NOT NULL, `score` integer NOT NULL, `host_id` bigint NOT NULL);" +
|
||||
"ALTER TABLE `{{defender_events}}` ADD CONSTRAINT `{{prefix}}defender_events_host_id_fk_defender_hosts_id` " +
|
||||
"FOREIGN KEY (`host_id`) REFERENCES `{{defender_hosts}}` (`id`) ON DELETE CASCADE;" +
|
||||
"CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" +
|
||||
"CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" +
|
||||
"CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);"
|
||||
mysqlV15DownSQL = "DROP TABLE `{{defender_events}}` CASCADE;" +
|
||||
"DROP TABLE `{{defender_hosts}}` CASCADE;"
|
||||
)
|
||||
|
||||
// MySQLProvider auth provider for MySQL/MariaDB database
|
||||
|
@ -311,6 +324,38 @@ func (p *MySQLProvider) updateShareLastUse(shareID string, numTokens int) error
|
|||
return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) updateDefenderBanTime(ip string, minutes int) error {
|
||||
return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) deleteDefenderHost(ip string) error {
|
||||
return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) addDefenderEvent(ip string, score int) error {
|
||||
return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) setDefenderBanTime(ip string, banTime int64) error {
|
||||
return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) cleanupDefender(from int64) error {
|
||||
return sqlCommonDefenderCleanup(from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *MySQLProvider) close() error {
|
||||
return p.dbHandle.Close()
|
||||
}
|
||||
|
@ -362,6 +407,8 @@ func (p *MySQLProvider) migrateDatabase() error {
|
|||
return updateMySQLDatabaseFromV12(p.dbHandle)
|
||||
case version == 13:
|
||||
return updateMySQLDatabaseFromV13(p.dbHandle)
|
||||
case version == 14:
|
||||
return updateMySQLDatabaseFromV14(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
|
||||
|
@ -384,6 +431,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 15:
|
||||
return downgradeMySQLDatabaseFromV15(p.dbHandle)
|
||||
case 14:
|
||||
return downgradeMySQLDatabaseFromV14(p.dbHandle)
|
||||
case 13:
|
||||
|
@ -405,6 +454,8 @@ func (p *MySQLProvider) resetDatabase() error {
|
|||
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
|
||||
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
|
||||
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0)
|
||||
}
|
||||
|
||||
|
@ -430,7 +481,21 @@ func updateMySQLDatabaseFromV12(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updateMySQLDatabaseFromV13(dbHandle *sql.DB) error {
|
||||
return updateMySQLDatabaseFrom13To14(dbHandle)
|
||||
if err := updateMySQLDatabaseFrom13To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updateMySQLDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
return updateMySQLDatabaseFrom14To15(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFromV15(dbHandle *sql.DB) error {
|
||||
if err := downgradeMySQLDatabaseFrom15To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradeMySQLDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
|
@ -467,6 +532,23 @@ func updateMySQLDatabaseFrom13To14(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 14)
|
||||
}
|
||||
|
||||
func updateMySQLDatabaseFrom14To15(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database version: 14 -> 15")
|
||||
providerLog(logger.LevelInfo, "updating database version: 14 -> 15")
|
||||
sql := strings.ReplaceAll(mysqlV15SQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFrom15To14(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 15 -> 14")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 15 -> 14")
|
||||
sql := strings.ReplaceAll(mysqlV15DownSQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 14)
|
||||
}
|
||||
|
||||
func downgradeMySQLDatabaseFrom14To13(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 14 -> 13")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 14 -> 13")
|
||||
|
|
|
@ -27,6 +27,8 @@ DROP TABLE IF EXISTS "{{admins}}" CASCADE;
|
|||
DROP TABLE IF EXISTS "{{folders}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{shares}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{users}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{defender_events}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE;
|
||||
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
|
||||
`
|
||||
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
|
||||
|
@ -97,6 +99,20 @@ REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE
|
|||
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id");
|
||||
`
|
||||
pgsqlV14DownSQL = `DROP TABLE "{{shares}}" CASCADE;`
|
||||
pgsqlV15SQL = `CREATE TABLE "{{defender_hosts}}" ("id" bigserial NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE,
|
||||
"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
|
||||
CREATE TABLE "{{defender_events}}" ("id" bigserial NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, "score" integer NOT NULL,
|
||||
"host_id" bigint NOT NULL);
|
||||
ALTER TABLE "{{defender_events}}" ADD CONSTRAINT "{{prefix}}defender_events_host_id_fk_defender_hosts_id" FOREIGN KEY
|
||||
("host_id") REFERENCES "{{defender_hosts}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
|
||||
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at");
|
||||
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
|
||||
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
|
||||
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
|
||||
`
|
||||
pgsqlV15DownSQL = `DROP TABLE "{{defender_events}}" CASCADE;
|
||||
DROP TABLE "{{defender_hosts}}" CASCADE;
|
||||
`
|
||||
)
|
||||
|
||||
// PGSQLProvider auth provider for PostgreSQL database
|
||||
|
@ -326,6 +342,38 @@ func (p *PGSQLProvider) updateShareLastUse(shareID string, numTokens int) error
|
|||
return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) updateDefenderBanTime(ip string, minutes int) error {
|
||||
return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) deleteDefenderHost(ip string) error {
|
||||
return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) addDefenderEvent(ip string, score int) error {
|
||||
return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) setDefenderBanTime(ip string, banTime int64) error {
|
||||
return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) cleanupDefender(from int64) error {
|
||||
return sqlCommonDefenderCleanup(from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *PGSQLProvider) close() error {
|
||||
return p.dbHandle.Close()
|
||||
}
|
||||
|
@ -383,6 +431,8 @@ func (p *PGSQLProvider) migrateDatabase() error {
|
|||
return updatePGSQLDatabaseFromV12(p.dbHandle)
|
||||
case version == 13:
|
||||
return updatePGSQLDatabaseFromV13(p.dbHandle)
|
||||
case version == 14:
|
||||
return updatePGSQLDatabaseFromV14(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
|
||||
|
@ -405,6 +455,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 15:
|
||||
return downgradePGSQLDatabaseFromV15(p.dbHandle)
|
||||
case 14:
|
||||
return downgradePGSQLDatabaseFromV14(p.dbHandle)
|
||||
case 13:
|
||||
|
@ -426,6 +478,8 @@ func (p *PGSQLProvider) resetDatabase() error {
|
|||
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
|
||||
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
|
||||
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0)
|
||||
}
|
||||
|
||||
|
@ -451,7 +505,21 @@ func updatePGSQLDatabaseFromV12(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updatePGSQLDatabaseFromV13(dbHandle *sql.DB) error {
|
||||
return updatePGSQLDatabaseFrom13To14(dbHandle)
|
||||
if err := updatePGSQLDatabaseFrom13To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updatePGSQLDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func updatePGSQLDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
return updatePGSQLDatabaseFrom14To15(dbHandle)
|
||||
}
|
||||
|
||||
func downgradePGSQLDatabaseFromV15(dbHandle *sql.DB) error {
|
||||
if err := downgradePGSQLDatabaseFrom15To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradePGSQLDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func downgradePGSQLDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
|
@ -488,6 +556,23 @@ func updatePGSQLDatabaseFrom13To14(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 14)
|
||||
}
|
||||
|
||||
func updatePGSQLDatabaseFrom14To15(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database version: 14 -> 15")
|
||||
providerLog(logger.LevelInfo, "updating database version: 14 -> 15")
|
||||
sql := strings.ReplaceAll(pgsqlV15SQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15)
|
||||
}
|
||||
|
||||
func downgradePGSQLDatabaseFrom15To14(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 15 -> 14")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 15 -> 14")
|
||||
sql := strings.ReplaceAll(pgsqlV15DownSQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 14)
|
||||
}
|
||||
|
||||
func downgradePGSQLDatabaseFrom14To13(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 14 -> 13")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 14 -> 13")
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 14
|
||||
sqlDatabaseVersion = 15
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
|
@ -971,6 +971,267 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
|
|||
return getUsersWithVirtualFolders(ctx, users, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]*DefenderEntry, error) {
|
||||
hosts := make([]*DefenderEntry, 0, 100)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostsQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.QueryContext(ctx, from, limit)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var idForScores []int64
|
||||
|
||||
for rows.Next() {
|
||||
var banTime sql.NullInt64
|
||||
host := DefenderEntry{}
|
||||
err = rows.Scan(&host.ID, &host.IP, &banTime)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
var hostBanTime time.Time
|
||||
if banTime.Valid && banTime.Int64 > 0 {
|
||||
hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
|
||||
}
|
||||
if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
|
||||
idForScores = append(idForScores, host.ID)
|
||||
} else {
|
||||
host.BanTime = hostBanTime
|
||||
}
|
||||
hosts = append(hosts, &host)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
|
||||
return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (*DefenderEntry, error) {
|
||||
var host DefenderEntry
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderIsHostBannedQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := stmt.QueryRowContext(ctx, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
err = row.Scan(&host.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
providerLog(logger.LevelError, "unable to check ban status for host %#v: %v", ip, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (*DefenderEntry, error) {
|
||||
var host DefenderEntry
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := stmt.QueryRowContext(ctx, ip, from)
|
||||
var banTime sql.NullInt64
|
||||
err = row.Scan(&host.ID, &host.IP, &banTime)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
providerLog(logger.LevelError, "unable to get host for ip %#v: %v", ip, err)
|
||||
return nil, err
|
||||
}
|
||||
if banTime.Valid && banTime.Int64 > 0 {
|
||||
hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
|
||||
if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
|
||||
host.BanTime = hostBanTime
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
hosts, err := getDefenderHostsWithScores(ctx, []*DefenderEntry{&host}, from, []int64{host.ID}, dbHandle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
|
||||
return hosts[0], nil
|
||||
}
|
||||
|
||||
func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDefenderIncrementBanTimeQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, minutesToAdd*60000, ip)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "ban time updated for ip %#v, increment (minutes): %v",
|
||||
ip, minutesToAdd)
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error updating ban time for ip %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDefenderSetBanTimeQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, banTime, ip)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "ip %#v banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error setting ban time for ip %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDeleteDefenderHostQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, ip)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
|
||||
})
|
||||
}
|
||||
|
||||
func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
|
||||
if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonCleanupDefenderHosts(from, dbHandler)
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
|
||||
q := getAddDefenderHostQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to add defender host %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
|
||||
q := getAddDefenderEventQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to add defender event for %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostsCleanupQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderEventsCleanupQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getShareFromDbRow(row sqlScanner) (Share, error) {
|
||||
var share Share
|
||||
var description, password, allowFrom, paths sql.NullString
|
||||
|
@ -1449,6 +1710,61 @@ func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQueri
|
|||
return users[0], err
|
||||
}
|
||||
|
||||
func getDefenderHostsWithScores(ctx context.Context, hosts []*DefenderEntry, from int64, idForScores []int64,
|
||||
dbHandle sqlQuerier) (
|
||||
[]*DefenderEntry,
|
||||
error,
|
||||
) {
|
||||
if len(idForScores) == 0 {
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
hostsWithScores := make(map[int64]int)
|
||||
q := getDefenderEventsQuery(idForScores)
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var hostID int64
|
||||
var score int
|
||||
err = rows.Scan(&hostID, &score)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error scanning host score row: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
if score > 0 {
|
||||
hostsWithScores[hostID] = score
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
|
||||
result := make([]*DefenderEntry, 0, len(hosts))
|
||||
|
||||
for idx := range hosts {
|
||||
hosts[idx].Score = hostsWithScores[hosts[idx].ID]
|
||||
if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
|
||||
result = append(result, hosts[idx])
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
|
||||
if len(users) == 0 {
|
||||
return users, nil
|
||||
|
|
|
@ -28,6 +28,8 @@ DROP TABLE IF EXISTS "{{admins}}";
|
|||
DROP TABLE IF EXISTS "{{folders}}";
|
||||
DROP TABLE IF EXISTS "{{shares}}";
|
||||
DROP TABLE IF EXISTS "{{users}}";
|
||||
DROP TABLE IF EXISTS "{{defender_events}}";
|
||||
DROP TABLE IF EXISTS "{{defender_hosts}}";
|
||||
DROP TABLE IF EXISTS "{{schema_version}}";
|
||||
`
|
||||
sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
|
||||
|
@ -86,6 +88,19 @@ ALTER TABLE "{{admins}}" DROP COLUMN "last_login";
|
|||
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id");
|
||||
`
|
||||
sqliteV14DownSQL = `DROP TABLE "{{shares}}";`
|
||||
sqliteV15SQL = `CREATE TABLE "{{defender_hosts}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"ip" varchar(50) NOT NULL UNIQUE, "ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
|
||||
CREATE TABLE "{{defender_events}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "date_time" bigint NOT NULL,
|
||||
"score" integer NOT NULL, "host_id" integer NOT NULL REFERENCES "{{defender_hosts}}" ("id") ON DELETE CASCADE
|
||||
DEFERRABLE INITIALLY DEFERRED);
|
||||
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at");
|
||||
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
|
||||
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
|
||||
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
|
||||
`
|
||||
sqliteV15DownSQL = `DROP TABLE "{{defender_events}}";
|
||||
DROP TABLE "{{defender_hosts}}";
|
||||
`
|
||||
)
|
||||
|
||||
// SQLiteProvider auth provider for SQLite database
|
||||
|
@ -308,6 +323,38 @@ func (p *SQLiteProvider) updateShareLastUse(shareID string, numTokens int) error
|
|||
return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
|
||||
return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
|
||||
return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) updateDefenderBanTime(ip string, minutes int) error {
|
||||
return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) deleteDefenderHost(ip string) error {
|
||||
return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) addDefenderEvent(ip string, score int) error {
|
||||
return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) setDefenderBanTime(ip string, banTime int64) error {
|
||||
return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) cleanupDefender(from int64) error {
|
||||
return sqlCommonDefenderCleanup(from, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p *SQLiteProvider) close() error {
|
||||
return p.dbHandle.Close()
|
||||
}
|
||||
|
@ -359,6 +406,8 @@ func (p *SQLiteProvider) migrateDatabase() error {
|
|||
return updateSQLiteDatabaseFromV12(p.dbHandle)
|
||||
case version == 13:
|
||||
return updateSQLiteDatabaseFromV13(p.dbHandle)
|
||||
case version == 14:
|
||||
return updateSQLiteDatabaseFromV14(p.dbHandle)
|
||||
default:
|
||||
if version > sqlDatabaseVersion {
|
||||
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
|
||||
|
@ -381,6 +430,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
|
|||
}
|
||||
|
||||
switch dbVersion.Version {
|
||||
case 15:
|
||||
return downgradeSQLiteDatabaseFromV15(p.dbHandle)
|
||||
case 14:
|
||||
return downgradeSQLiteDatabaseFromV14(p.dbHandle)
|
||||
case 13:
|
||||
|
@ -402,6 +453,8 @@ func (p *SQLiteProvider) resetDatabase() error {
|
|||
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
|
||||
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
|
||||
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0)
|
||||
}
|
||||
|
||||
|
@ -427,7 +480,21 @@ func updateSQLiteDatabaseFromV12(dbHandle *sql.DB) error {
|
|||
}
|
||||
|
||||
func updateSQLiteDatabaseFromV13(dbHandle *sql.DB) error {
|
||||
return updateSQLiteDatabaseFrom13To14(dbHandle)
|
||||
if err := updateSQLiteDatabaseFrom13To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return updateSQLiteDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func updateSQLiteDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
return updateSQLiteDatabaseFrom14To15(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFromV15(dbHandle *sql.DB) error {
|
||||
if err := downgradeSQLiteDatabaseFrom15To14(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
return downgradeSQLiteDatabaseFromV14(dbHandle)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFromV14(dbHandle *sql.DB) error {
|
||||
|
@ -464,6 +531,23 @@ func updateSQLiteDatabaseFrom13To14(dbHandle *sql.DB) error {
|
|||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 14)
|
||||
}
|
||||
|
||||
func updateSQLiteDatabaseFrom14To15(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("updating database version: 14 -> 15")
|
||||
providerLog(logger.LevelInfo, "updating database version: 14 -> 15")
|
||||
sql := strings.ReplaceAll(sqliteV15SQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFrom15To14(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 15 -> 14")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 15 -> 14")
|
||||
sql := strings.ReplaceAll(sqliteV15DownSQL, "{{defender_events}}", sqlTableDefenderEvents)
|
||||
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
|
||||
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 14)
|
||||
}
|
||||
|
||||
func downgradeSQLiteDatabaseFrom14To13(dbHandle *sql.DB) error {
|
||||
logger.InfoToConsole("downgrading database version: 14 -> 13")
|
||||
providerLog(logger.LevelInfo, "downgrading database version: 14 -> 13")
|
||||
|
|
|
@ -31,6 +31,79 @@ func getSQLPlaceholders() []string {
|
|||
return placeholders
|
||||
}
|
||||
|
||||
func getAddDefenderHostQuery() string {
|
||||
if config.Driver == MySQLDataProviderName {
|
||||
return fmt.Sprintf("INSERT INTO %v (`ip`,`updated_at`,`ban_time`) VALUES (%v,%v,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)",
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
return fmt.Sprintf(`INSERT INTO %v (ip,updated_at,ban_time) VALUES (%v,%v,0) ON CONFLICT (ip) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getAddDefenderEventQuery() string {
|
||||
return fmt.Sprintf(`INSERT INTO %v (date_time,score,host_id) VALUES (%v,%v,(SELECT id from %v WHERE ip = %v))`,
|
||||
sqlTableDefenderEvents, sqlPlaceholders[0], sqlPlaceholders[1], sqlTableDefenderHosts, sqlPlaceholders[2])
|
||||
}
|
||||
|
||||
func getDefenderHostsQuery() string {
|
||||
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v ORDER BY updated_at DESC LIMIT %v`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDefenderHostQuery() string {
|
||||
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND updated_at >= %v`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDefenderEventsQuery(hostIDS []int64) string {
|
||||
var sb strings.Builder
|
||||
for _, hID := range hostIDS {
|
||||
if sb.Len() == 0 {
|
||||
sb.WriteString("(")
|
||||
} else {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(strconv.FormatInt(hID, 10))
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString(")")
|
||||
} else {
|
||||
sb.WriteString("(0)")
|
||||
}
|
||||
return fmt.Sprintf(`SELECT host_id,SUM(score) FROM %v WHERE date_time >= %v AND host_id IN %v GROUP BY host_id`,
|
||||
sqlTableDefenderEvents, sqlPlaceholders[0], sb.String())
|
||||
}
|
||||
|
||||
func getDefenderIsHostBannedQuery() string {
|
||||
return fmt.Sprintf(`SELECT id FROM %v WHERE ip = %v AND ban_time >= %v`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDefenderIncrementBanTimeQuery() string {
|
||||
return fmt.Sprintf(`UPDATE %v SET ban_time = ban_time + %v WHERE ip = %v`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDefenderSetBanTimeQuery() string {
|
||||
return fmt.Sprintf(`UPDATE %v SET ban_time = %v WHERE ip = %v`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDeleteDefenderHostQuery() string {
|
||||
return fmt.Sprintf(`DELETE FROM %v WHERE ip = %v`, sqlTableDefenderHosts, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getDefenderHostsCleanupQuery() string {
|
||||
return fmt.Sprintf(`DELETE FROM %v WHERE ban_time < %v AND NOT EXISTS (
|
||||
SELECT id FROM %v WHERE %v.host_id = %v.id AND %v.date_time > %v)`,
|
||||
sqlTableDefenderHosts, sqlPlaceholders[0], sqlTableDefenderEvents, sqlTableDefenderEvents, sqlTableDefenderHosts,
|
||||
sqlTableDefenderEvents, sqlPlaceholders[1])
|
||||
}
|
||||
|
||||
func getDefenderEventsCleanupQuery() string {
|
||||
return fmt.Sprintf(`DELETE FROM %v WHERE date_time < %v`, sqlTableDefenderEvents, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getAdminByUsernameQuery() string {
|
||||
return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectAdminFields, sqlTableAdmins, sqlPlaceholders[0])
|
||||
}
|
||||
|
|
|
@ -26,7 +26,13 @@ If an already banned client tries to log in again, its ban time will be incremen
|
|||
|
||||
The `ban_time_increment` is calculated as percentage of `ban_time`, so if `ban_time` is 30 minutes and `ban_time_increment` is 50 the host will be banned for additionally 15 minutes. You can also specify values greater than 100 for `ban_time_increment` if you want to increase the penalty for already banned hosts.
|
||||
|
||||
The `defender` will keep in memory both the host scores and the banned hosts, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
|
||||
SFTPGo can store host scores and banned hosts in memory or within the configured data provider according to the `driver` set in the `defender` configuration section. The available drivers are `memory` and `provider`.
|
||||
The `provider` driver is useful if you want to share the defender data across multiple SFTPGo instances and it requires a shared or distributed data provider: `MySQL`, `PostgreSQL` and `CockroachDB` are supported.
|
||||
If you set the `provider` driver, the defender implementation may do many database queries (at least one query every time a new client connects to check if it is banned), if you have a single SFTPGo instance the `memory` driver is recommended.
|
||||
|
||||
For the `memory` driver, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
|
||||
|
||||
The `provider` driver will periodically clean up expired hosts and events.
|
||||
|
||||
Using the REST API you can:
|
||||
|
||||
|
@ -58,6 +64,4 @@ Here is a small example:
|
|||
}
|
||||
```
|
||||
|
||||
These list will be loaded in memory for faster lookups. The REST API queries "live" data and not these lists.
|
||||
|
||||
The `defender` is optimized for fast and time constant lookups however as it keeps all the lists and the entries in memory you should carefully measure the memory requirements for your use case.
|
||||
These list will be always loaded in memory (even if you use the `provider` driver) for faster lookups. The REST API queries "live" data and not these lists.
|
||||
|
|
|
@ -75,6 +75,7 @@ The configuration file contains the following sections:
|
|||
- `max_per_host_connections`, integer. Maximum number of concurrent client connections from the same host (IP). If the defender is enabled, exceeding this limit will generate `score_limit_exceeded` events and thus hosts that repeatedly exceed the max allowed connections can be automatically blocked. 0 means unlimited. Default: 20.
|
||||
- `defender`, struct containing the defender configuration. See [Defender](./defender.md) for more details.
|
||||
- `enabled`, boolean. Default `false`.
|
||||
- `driver`, string. Supported drivers are `memory` and `provider`. The `provider` driver will use the configured data provider to store defender events and it is supported for `MySQL`, `PostgreSQL` and `CockroachDB` data providers. Using the `provider` driver you can share the defender events among multiple SFTPGO instances. For a single instance the `memory` driver will be much faster. Default: `memory`.
|
||||
- `ban_time`, integer. Ban time in minutes.
|
||||
- `ban_time_increment`, integer. Ban time increment, as a percentage, if a banned host tries to connect again.
|
||||
- `threshold`, integer. Threshold value for banning a client.
|
||||
|
@ -82,8 +83,8 @@ The configuration file contains the following sections:
|
|||
- `score_valid`, integer. Score for valid login attempts, eg. user accounts that exist.
|
||||
- `score_limit_exceeded`, integer. Score for hosts that exceeded the configured rate limits or the maximum, per-host, allowed connections.
|
||||
- `observation_time`, integer. Defines the time window, in minutes, for tracking client errors. A host is banned if it has exceeded the defined threshold during the last observation time minutes.
|
||||
- `entries_soft_limit`, integer.
|
||||
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit.
|
||||
- `entries_soft_limit`, integer. Ignored for `provider` driver. Default: 100.
|
||||
- `entries_hard_limit`, integer. The number of banned IPs and host scores kept in memory will vary between the soft and hard limit for `memory` driver. If you use the `provider` driver, this setting will limit the number of entries to return when you ask for the entire host list from the defender. Default: 150.
|
||||
- `safelist_file`, string. Path to a file containing a list of ip addresses and/or networks to never ban.
|
||||
- `blocklist_file`, string. Path to a file containing a list of ip addresses and/or networks to always ban. The lists can be reloaded on demand sending a `SIGHUP` signal on Unix based systems and a `paramchange` request to the running service on Windows. An host that is already banned will not be automatically unbanned if you put it inside the safe list, you have to unban it using the REST API.
|
||||
- `rate_limiters`, list of structs containing the rate limiters configuration. Take a look [here](./rate-limiting.md) for more details. Each struct has the following fields:
|
||||
|
|
15
go.mod
15
go.mod
|
@ -7,7 +7,7 @@ require (
|
|||
github.com/Azure/azure-storage-blob-go v0.14.0
|
||||
github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962
|
||||
github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387
|
||||
github.com/aws/aws-sdk-go v1.42.23
|
||||
github.com/aws/aws-sdk-go v1.42.25
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.2.5
|
||||
github.com/eikenb/pipeat v0.0.0-20210603033007-44fc3ffce52b
|
||||
github.com/fclairamb/ftpserverlib v0.16.0
|
||||
|
@ -25,7 +25,7 @@ require (
|
|||
github.com/hashicorp/go-retryablehttp v0.7.0
|
||||
github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126
|
||||
github.com/klauspost/compress v1.13.6
|
||||
github.com/lestrrat-go/jwx v1.2.13
|
||||
github.com/lestrrat-go/jwx v1.2.14
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/lithammer/shortuuid/v3 v3.0.7
|
||||
github.com/mattn/go-sqlite3 v1.14.9
|
||||
|
@ -36,11 +36,11 @@ require (
|
|||
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae
|
||||
github.com/pquerna/otp v1.3.0
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/rs/cors v1.8.0
|
||||
github.com/rs/cors v1.8.2
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/rs/zerolog v1.26.2-0.20211219225053-665519c4da50
|
||||
github.com/shirou/gopsutil/v3 v3.21.11
|
||||
github.com/spf13/afero v1.6.0
|
||||
github.com/spf13/afero v1.7.0
|
||||
github.com/spf13/cobra v1.3.0
|
||||
github.com/spf13/viper v1.10.1
|
||||
github.com/stretchr/testify v1.7.0
|
||||
|
@ -85,6 +85,7 @@ require (
|
|||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/google/go-cmp v0.5.6 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.1.1 // indirect
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/yamux v0.0.0-20211028200310-0bc27b27de87 // indirect
|
||||
|
@ -104,7 +105,7 @@ require (
|
|||
github.com/mattn/go-ieproxy v0.0.1 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/miekg/dns v1.1.43 // indirect
|
||||
github.com/miekg/dns v1.1.45 // indirect
|
||||
github.com/minio/sha256-simd v1.0.0 // indirect
|
||||
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.4.3 // indirect
|
||||
|
@ -125,11 +126,13 @@ require (
|
|||
github.com/tklauser/numcpus v0.3.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.2 // indirect
|
||||
go.opencensus.io v0.23.0 // indirect
|
||||
golang.org/x/mod v0.5.1 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/tools v0.1.8 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa // indirect
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb // indirect
|
||||
gopkg.in/ini.v1 v1.66.2 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
|
|
45
go.sum
45
go.sum
|
@ -4,6 +4,7 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
|
|||
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
|
||||
cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
|
||||
cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
|
||||
cloud.google.com/go v0.44.3/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
|
||||
cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
|
||||
cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
|
||||
cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To=
|
||||
|
@ -60,6 +61,7 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo
|
|||
cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
|
||||
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
|
||||
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
|
||||
cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo=
|
||||
cloud.google.com/go/storage v1.16.1/go.mod h1:LaNorbty3ehnU3rEjXSNV/NRgQA0O8Y+uh6bPe5UOk4=
|
||||
cloud.google.com/go/storage v1.18.2 h1:5NQw6tOn3eMm0oE8vTkfjau18kjL79FlMjy/CHTpmoY=
|
||||
cloud.google.com/go/storage v1.18.2/go.mod h1:AiIj7BWXyhO5gGVmYJ+S8tbkCx3yb0IMjua8Aw4naVM=
|
||||
|
@ -138,8 +140,8 @@ github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZo
|
|||
github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
|
||||
github.com/aws/aws-sdk-go v1.38.68/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
|
||||
github.com/aws/aws-sdk-go v1.40.34/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q=
|
||||
github.com/aws/aws-sdk-go v1.42.23 h1:V0V5hqMEyVelgpu1e4gMPVCJ+KhmscdNxP/NWP1iCOA=
|
||||
github.com/aws/aws-sdk-go v1.42.23/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs=
|
||||
github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs=
|
||||
github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs=
|
||||
github.com/aws/aws-sdk-go-v2 v1.7.0/go.mod h1:tb9wi5s61kTDA5qCkcDbt3KRVV74GGslQkl/DRdX/P4=
|
||||
github.com/aws/aws-sdk-go-v2 v1.9.0/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.7.0/go.mod h1:w9+nMZ7soXCe5nT46Ri354SNhXDQ6v+V5wqDjnZE+GY=
|
||||
|
@ -264,7 +266,6 @@ github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWp
|
|||
github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
|
||||
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
|
||||
github.com/go-chi/chi/v5 v5.0.4/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
|
||||
|
@ -289,9 +290,7 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG
|
|||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
|
||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
|
@ -407,6 +406,9 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
|
|||
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
|
||||
github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU=
|
||||
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720 h1:zC34cGQu69FG7qzJ3WiKW244WfhDC3xxYMeNOX2gtUQ=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||
|
@ -522,7 +524,6 @@ github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXL
|
|||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
|
@ -557,7 +558,6 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||
github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A=
|
||||
github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
|
||||
|
@ -569,8 +569,8 @@ github.com/lestrrat-go/httpcc v1.0.0/go.mod h1:tGS/u00Vh5N6FHNkExqGGNId8e0Big+++
|
|||
github.com/lestrrat-go/iter v1.0.1 h1:q8faalr2dY6o8bV45uwrxq12bRa1ezKrB6oM9FUgN4A=
|
||||
github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
|
||||
github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU=
|
||||
github.com/lestrrat-go/jwx v1.2.13 h1:GxuOPfAz4+nzL98WaKxBxEEZ9b7qmyDetMGfBm9yVvE=
|
||||
github.com/lestrrat-go/jwx v1.2.13/go.mod h1:3Q3Re8TaOcVTdpx4Tvz++OWmryDklihTDqrrwQiyS2A=
|
||||
github.com/lestrrat-go/jwx v1.2.14 h1:69OeaiFKCTn8xDmBGzHTgv/GBoO1LJcXw99GfYCDKzg=
|
||||
github.com/lestrrat-go/jwx v1.2.14/go.mod h1:3Q3Re8TaOcVTdpx4Tvz++OWmryDklihTDqrrwQiyS2A=
|
||||
github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4=
|
||||
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
|
@ -619,8 +619,8 @@ github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3N
|
|||
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
||||
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
|
||||
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
|
||||
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
||||
github.com/miekg/dns v1.1.45 h1:g5fRIhm9nx7g8osrAvgb16QJfmyMsyOCb+J7LSv+Qzk=
|
||||
github.com/miekg/dns v1.1.45/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||
github.com/minio/highwayhash v1.0.1/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
|
||||
github.com/minio/sha256-simd v1.0.0 h1:v1ta+49hkWZyvaKwrQB8elexRqm6Y0aMLjCNsrYxo6g=
|
||||
github.com/minio/sha256-simd v1.0.0/go.mod h1:OuYzVNI5vcoYIAmbIvHPl3N3jUzVedXbKy5RFepssQM=
|
||||
|
@ -682,6 +682,7 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA=
|
||||
github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
|
||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae h1:J8MHmz3LSjRtoR4SKiPq8BNo3DacJl5kQRjJeWilkUI=
|
||||
github.com/pkg/sftp v1.13.5-0.20211217081921-1849af66afae/go.mod h1:wHDZ0IZX6JcBYRK1TH9bcVq8G7TLpVHYIGJRFnmPfxg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
@ -719,8 +720,8 @@ github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
|
|||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
|
||||
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
|
||||
github.com/rs/cors v1.8.2 h1:KCooALfAYGs415Cwu5ABvv9n9509fSiG5SQJn/AQo4U=
|
||||
github.com/rs/cors v1.8.2/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
|
||||
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
|
@ -752,8 +753,9 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9
|
|||
github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4=
|
||||
github.com/spf13/afero v1.6.0 h1:xoax2sJ2DT8S8xA2paPFjDCScCNeWsg75VG0DLRreiY=
|
||||
github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
|
||||
github.com/spf13/afero v1.7.0 h1:xc1yh8vgcNB8yQ+UqY4cpD56Ogo573e+CJ/C4YmMFTg=
|
||||
github.com/spf13/afero v1.7.0/go.mod h1:CtAatgMJh6bJEIs48Ay/FOnkljP3WeGUG0MC1RfAqwo=
|
||||
github.com/spf13/cast v1.4.1 h1:s0hze+J0196ZfEMTs80N7UlFt0BDuQ7Q+JDnHiMWKdA=
|
||||
github.com/spf13/cast v1.4.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
github.com/spf13/cobra v1.3.0 h1:R7cSvGu+Vv+qX0gW5R/85dx2kmmJT5z5NM8ifdYjdn0=
|
||||
|
@ -800,6 +802,7 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
|
|||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg=
|
||||
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
|
@ -879,6 +882,8 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
|||
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38=
|
||||
golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
@ -967,11 +972,13 @@ golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210223095934-7937bea0104d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210503080704-8803ae5d1324/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
@ -991,6 +998,7 @@ golang.org/x/sys v0.0.0-20210917161153-d61c044b1678/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
@ -1080,6 +1088,9 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
|||
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.8 h1:P1HhGGuLW4aAclzjtmJdf0mJOjVUZUzOTqkAkWL+l6w=
|
||||
golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
|
||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
@ -1172,6 +1183,7 @@ google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6D
|
|||
google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210126160654-44e461bb6506/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
|
@ -1205,8 +1217,9 @@ google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ6
|
|||
google.golang.org/genproto v0.0.0-20211129164237-f09f9a12af12/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20211203200212-54befc351ae9/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa h1:I0YcKz0I7OAhddo7ya8kMnvprhcWM045PmkBdMO9zN0=
|
||||
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb h1:ZrsicilzPCS/Xr8qtBZZLpy4P9TYXAfl49ctG1/5tgw=
|
||||
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
|
||||
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
|
@ -1261,8 +1274,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
|||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||
gopkg.in/ini.v1 v1.66.2 h1:XfR1dOYubytKy4Shzc2LHrrGhU0lDCfDGG1yLPmpgsI=
|
||||
gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
|
|
|
@ -11,13 +11,19 @@ import (
|
|||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/common"
|
||||
"github.com/drakkan/sftpgo/v2/dataprovider"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
func getDefenderHosts(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
hosts := common.GetDefenderHosts()
|
||||
hosts, err := common.GetDefenderHosts()
|
||||
if err != nil {
|
||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
if hosts == nil {
|
||||
render.JSON(w, r, make([]common.DefenderEntry, 0))
|
||||
render.JSON(w, r, make([]dataprovider.DefenderEntry, 0))
|
||||
return
|
||||
}
|
||||
render.JSON(w, r, hosts)
|
||||
|
@ -64,7 +70,15 @@ func getBanTime(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
banStatus := make(map[string]*string)
|
||||
|
||||
banTime := common.GetDefenderBanTime(ip)
|
||||
banTime, err := common.GetDefenderBanTime(ip)
|
||||
if err != nil {
|
||||
if _, ok := err.(*util.RecordNotFoundError); ok {
|
||||
banTime = nil
|
||||
} else {
|
||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
var banTimeString *string
|
||||
if banTime != nil {
|
||||
rfc3339String := banTime.UTC().Format(time.RFC3339)
|
||||
|
@ -84,8 +98,18 @@ func getScore(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
score, err := common.GetDefenderScore(ip)
|
||||
if err != nil {
|
||||
if _, ok := err.(*util.RecordNotFoundError); ok {
|
||||
score = 0
|
||||
} else {
|
||||
sendAPIResponse(w, r, err, "", getRespStatus(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
scoreStatus := make(map[string]int)
|
||||
scoreStatus["score"] = common.GetDefenderScore(ip)
|
||||
scoreStatus["score"] = score
|
||||
|
||||
render.JSON(w, r, scoreStatus)
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ func getRespStatus(err error) int {
|
|||
if os.IsPermission(err) {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
if errors.Is(err, plugin.ErrNoSearcher) {
|
||||
if errors.Is(err, plugin.ErrNoSearcher) || errors.Is(err, dataprovider.ErrNotImplemented) {
|
||||
return http.StatusNotImplemented
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
|
|
|
@ -83,7 +83,9 @@ const (
|
|||
updateUsedQuotaCompatPath = "/api/v2/quota-update"
|
||||
updateFolderUsedQuotaCompatPath = "/api/v2/folder-quota-update"
|
||||
defenderHosts = "/api/v2/defender/hosts"
|
||||
defenderBanTime = "/api/v2/defender/bantime"
|
||||
defenderUnban = "/api/v2/defender/unban"
|
||||
defenderScore = "/api/v2/defender/score"
|
||||
versionPath = "/api/v2/version"
|
||||
logoutPath = "/api/v2/logout"
|
||||
userPwdPath = "/api/v2/user/changepwd"
|
||||
|
@ -4292,106 +4294,118 @@ func TestDumpdata(t *testing.T) {
|
|||
func TestDefenderAPI(t *testing.T) {
|
||||
oldConfig := config.GetCommonConfig()
|
||||
|
||||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
drivers := []string{common.DefenderDriverMemory}
|
||||
if isDbDefenderSupported() {
|
||||
drivers = append(drivers, common.DefenderDriverProvider)
|
||||
}
|
||||
|
||||
err := common.Initialize(cfg)
|
||||
require.NoError(t, err)
|
||||
for _, driver := range drivers {
|
||||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Driver = driver
|
||||
cfg.DefenderConfig.Threshold = 3
|
||||
cfg.DefenderConfig.ScoreLimitExceeded = 2
|
||||
|
||||
ip := "::1"
|
||||
err := common.Initialize(cfg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, _, err := httpdtest.GetBanTime(ip, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
banTime, ok := response["date_time"]
|
||||
require.True(t, ok)
|
||||
assert.Nil(t, banTime)
|
||||
ip := "::1"
|
||||
|
||||
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
response, _, err := httpdtest.GetBanTime(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
banTime, ok := response["date_time"]
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, banTime)
|
||||
|
||||
response, _, err = httpdtest.GetScore(ip, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
score, ok := response["score"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(0), score)
|
||||
hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 0)
|
||||
|
||||
err = httpdtest.UnbanIP(ip, http.StatusNotFound)
|
||||
require.NoError(t, err)
|
||||
response, _, err = httpdtest.GetScore(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
score, ok := response["score"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, float64(0), score)
|
||||
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound)
|
||||
require.NoError(t, err)
|
||||
err = httpdtest.UnbanIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
response, _, err = httpdtest.GetScore(ip, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
score, ok = response["score"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(2), score)
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
response, _, err = httpdtest.GetScore(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
score, ok = response["score"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, float64(2), score)
|
||||
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 2, host.Score)
|
||||
assert.Equal(t, ip, host.IP)
|
||||
}
|
||||
host, _, err := httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 2, host.Score)
|
||||
assert.Equal(t, ip, host.IP)
|
||||
}
|
||||
host, _, err := httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, host.GetBanTime())
|
||||
assert.Equal(t, 2, host.Score)
|
||||
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
response, _, err = httpdtest.GetBanTime(ip, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
banTime, ok = response["date_time"]
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, banTime)
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
response, _, err = httpdtest.GetBanTime(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
banTime, ok = response["date_time"]
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, banTime)
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, hosts, 1) {
|
||||
host := hosts[0]
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.Equal(t, ip, host.IP)
|
||||
}
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
assert.Equal(t, 0, host.Score)
|
||||
assert.Equal(t, ip, host.IP)
|
||||
|
||||
err = httpdtest.UnbanIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = httpdtest.UnbanIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, hosts, 1)
|
||||
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP("invalid_ip", http.StatusBadRequest)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveDefenderHostByIP("invalid_ip", http.StatusBadRequest)
|
||||
assert.NoError(t, err)
|
||||
if driver == common.DefenderDriverProvider {
|
||||
err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour)))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, host.GetBanTime())
|
||||
assert.Equal(t, 0, host.Score)
|
||||
|
||||
err = httpdtest.UnbanIP(ip, http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = httpdtest.UnbanIP(ip, http.StatusNotFound)
|
||||
require.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
common.AddDefenderEvent(ip, common.HostEventNoLoginTried)
|
||||
hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, hosts, 1)
|
||||
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound)
|
||||
assert.NoError(t, err)
|
||||
|
||||
host, _, err = httpdtest.GetDefenderHostByIP("invalid_ip", http.StatusBadRequest)
|
||||
assert.NoError(t, err)
|
||||
_, err = httpdtest.RemoveDefenderHostByIP("invalid_ip", http.StatusBadRequest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = common.Initialize(oldConfig)
|
||||
err := common.Initialize(oldConfig)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -4407,6 +4421,53 @@ func TestDefenderAPIErrors(t *testing.T) {
|
|||
|
||||
err = httpdtest.UnbanIP("", http.StatusBadRequest)
|
||||
require.NoError(t, err)
|
||||
if isDbDefenderSupported() {
|
||||
oldConfig := config.GetCommonConfig()
|
||||
|
||||
cfg := config.GetCommonConfig()
|
||||
cfg.DefenderConfig.Enabled = true
|
||||
cfg.DefenderConfig.Driver = common.DefenderDriverProvider
|
||||
err := common.Initialize(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ip := "127.1.1.2"
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, defenderHosts, nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusInternalServerError, rr)
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, defenderBanTime+"?ip="+ip, nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusInternalServerError, rr)
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, defenderScore+"?ip="+ip, nil)
|
||||
assert.NoError(t, err)
|
||||
setBearerForReq(req, token)
|
||||
rr = executeRequest(req)
|
||||
checkResponseCode(t, http.StatusInternalServerError, rr)
|
||||
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.CredentialsPath = credentialsPath
|
||||
err = os.RemoveAll(credentialsPath)
|
||||
assert.NoError(t, err)
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = common.Initialize(oldConfig)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreShares(t *testing.T) {
|
||||
|
@ -15961,6 +16022,18 @@ func generateTOTPPasscode(secret string) (string, error) {
|
|||
})
|
||||
}
|
||||
|
||||
func isDbDefenderSupported() bool {
|
||||
// SQLite shares the implementation with other SQL-based provider but it makes no sense
|
||||
// to use it outside test cases
|
||||
switch dataprovider.GetProviderStatus().Driver {
|
||||
case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
|
||||
dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSecretDecryption(b *testing.B) {
|
||||
s := kms.NewPlainSecret("test data")
|
||||
s.SetAdditionalData("username")
|
||||
|
|
|
@ -799,8 +799,8 @@ func GetStatus(expectedStatusCode int) (httpd.ServicesStatus, []byte, error) {
|
|||
}
|
||||
|
||||
// GetDefenderHosts returns hosts that are banned or for which some violations have been detected
|
||||
func GetDefenderHosts(expectedStatusCode int) ([]common.DefenderEntry, []byte, error) {
|
||||
var response []common.DefenderEntry
|
||||
func GetDefenderHosts(expectedStatusCode int) ([]dataprovider.DefenderEntry, []byte, error) {
|
||||
var response []dataprovider.DefenderEntry
|
||||
var body []byte
|
||||
url, err := url.Parse(buildURLRelativeToBase(defenderHosts))
|
||||
if err != nil {
|
||||
|
@ -821,8 +821,8 @@ func GetDefenderHosts(expectedStatusCode int) ([]common.DefenderEntry, []byte, e
|
|||
}
|
||||
|
||||
// GetDefenderHostByIP returns the host with the given IP, if it exists
|
||||
func GetDefenderHostByIP(ip string, expectedStatusCode int) (common.DefenderEntry, []byte, error) {
|
||||
var host common.DefenderEntry
|
||||
func GetDefenderHostByIP(ip string, expectedStatusCode int) (dataprovider.DefenderEntry, []byte, error) {
|
||||
var host dataprovider.DefenderEntry
|
||||
var body []byte
|
||||
id := hex.EncodeToString([]byte(ip))
|
||||
resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(defenderHosts, id),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
NFPM_VERSION=2.11.0
|
||||
NFPM_VERSION=2.11.2
|
||||
NFPM_ARCH=${NFPM_ARCH:-amd64}
|
||||
if [ -z ${SFTPGO_VERSION} ]
|
||||
then
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
"max_per_host_connections": 20,
|
||||
"defender": {
|
||||
"enabled": false,
|
||||
"driver": "memory",
|
||||
"ban_time": 30,
|
||||
"ban_time_increment": 50,
|
||||
"threshold": 15,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package util
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ValidationError raised if input data is not valid
|
||||
type ValidationError struct {
|
||||
|
|
Loading…
Reference in a new issue