sftpgo-mirror/internal/dataprovider/iplist.go
Nicola Murino d94f80c8da
replace utils.Contains with slices.Contains
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 18:27:13 +02:00

494 lines
13 KiB
Go

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"encoding/json"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"sync/atomic"
"github.com/yl2chen/cidranger"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
// ipListMemoryLimit is the maximum number of entries matched in memory:
// if a list contains this many entries or more, each lookup executes a
// database query instead.
const ipListMemoryLimit = 15000

// inMemoryLists tracks, by list type, the IP lists created by NewIPList with
// in-memory matching enabled. It is initialized at declaration time.
var inMemoryLists = map[IPListType]*IPList{}
// IPListType is the enumerable for the supported IP list types
type IPListType int

// Supported IP list types
const (
	IPListTypeAllowList IPListType = iota + 1
	IPListTypeDefender
	IPListTypeRateLimiterSafeList
)

// ipListTypeNames holds the human readable label of each supported list type.
var ipListTypeNames = map[IPListType]string{
	IPListTypeAllowList:           "Allow list",
	IPListTypeDefender:            "Defender",
	IPListTypeRateLimiterSafeList: "Rate limiters safe list",
}

// AsString returns the string representation for the list type.
// Unknown types are rendered as the empty string.
func (t IPListType) AsString() string {
	return ipListTypeNames[t]
}

// Supported IP list modes
const (
	ListModeAllow = iota + 1
	ListModeDeny
)

// Address families stored in IPListEntry.IPType
const (
	ipTypeV4 = iota + 1
	ipTypeV6
)

// supportedIPListType enumerates every valid IPListType; CheckIPListType
// rejects anything not listed here.
var supportedIPListType = []IPListType{IPListTypeAllowList, IPListTypeDefender, IPListTypeRateLimiterSafeList}
// CheckIPListType returns a validation error unless t is one of the
// supported IP list types.
func CheckIPListType(t IPListType) error {
	if slices.Contains(supportedIPListType, t) {
		return nil
	}
	return util.NewValidationError(fmt.Sprintf("invalid list type %d", t))
}
// IPListEntry defines an entry for the IP addresses list.
// getKey builds an identifier for it from Type and IPOrNet.
type IPListEntry struct {
	// IP address or network in CIDR notation. validate turns a bare IP into a
	// /32 or /128 network and rewrites IPv4-mapped IPv6 values in IPv4 notation
	IPOrNet string `json:"ipornet"`
	// free-form description, omitted from JSON when empty
	Description string `json:"description,omitempty"`
	// the list this entry belongs to
	Type IPListType `json:"type"`
	// ListModeAllow or ListModeDeny; validate accepts ListModeDeny only for
	// the defender list
	Mode int `json:"mode"`
	// Defines the protocols the entry applies to
	// - 0 all the supported protocols
	// - 1 SSH
	// - 2 FTP
	// - 4 WebDAV
	// - 8 HTTP
	// Protocols can be combined. validate normalizes a mask that enables
	// every valid protocol to 0
	Protocols int `json:"protocols"`
	// First and Last hold the first and last address of the network as raw
	// bytes (4 for IPv4, 16 for IPv6). They are computed by validate and
	// cleared by PrepareForRendering
	First []byte `json:"first,omitempty"`
	Last []byte `json:"last,omitempty"`
	// ipTypeV4 or ipTypeV6, computed by validate together with First and Last
	IPType int `json:"ip_type,omitempty"`
	// Creation time as unix timestamp in milliseconds
	CreatedAt int64 `json:"created_at"`
	// Last update time as unix timestamp in milliseconds
	UpdatedAt int64 `json:"updated_at"`
	// In multi node setups we mark the rule as deleted to be able to update
	// the cache. Never serialized to JSON
	DeletedAt int64 `json:"-"`
}
// PrepareForRendering prepares an IP list entry for rendering by zeroing the
// internal fields derived by validate (First, Last and IPType). Their
// omitempty JSON tags then keep them out of the rendered output.
func (e *IPListEntry) PrepareForRendering() {
	e.First, e.Last, e.IPType = nil, nil, 0
}
// HasProtocol reports whether the Protocols bitmask includes proto.
// Unknown protocols always return false. A zero mask is NOT treated as
// "all protocols" here; callers such as IsListed check that case themselves.
func (e *IPListEntry) HasProtocol(proto string) bool {
	var mask int
	switch proto {
	case protocolSSH:
		mask = 1
	case protocolFTP:
		mask = 2
	case protocolWebDAV:
		mask = 4
	case protocolHTTP:
		mask = 8
	default:
		return false
	}
	return e.Protocols&mask != 0
}
// RenderAsJSON implements the renderer interface used within plugins.
//
// Without reload the receiver itself is cleaned with PrepareForRendering
// (so it is modified) and marshaled. With reload the entry is first fetched
// again from the data provider by IPOrNet and Type, and the fetched copy is
// cleaned and marshaled instead.
func (e *IPListEntry) RenderAsJSON(reload bool) ([]byte, error) {
	if !reload {
		e.PrepareForRendering()
		return json.Marshal(e)
	}
	fresh, err := provider.ipListEntryExists(e.IPOrNet, e.Type)
	if err != nil {
		providerLog(logger.LevelError, "unable to reload IP list entry before rendering as json: %v", err)
		return nil, err
	}
	fresh.PrepareForRendering()
	return json.Marshal(fresh)
}
// getKey returns an identifier built from the numeric list type and IPOrNet,
// e.g. "1_192.168.1.0/24" for an allow list entry.
func (e *IPListEntry) getKey() string {
	return fmt.Sprint(int(e.Type)) + "_" + e.IPOrNet
}
// getName returns a human readable name such as "Allow list-10.0.0.0/8".
func (e *IPListEntry) getName() string {
	return fmt.Sprintf("%s-%s", e.Type.AsString(), e.IPOrNet)
}
// ipListEntryAddrFromBytes decodes an address stored in IPListEntry.First or
// IPListEntry.Last. For ipTypeV4 the first 4 bytes are used, otherwise the
// first 16; shorter inputs are zero padded and extra bytes are ignored.
func ipListEntryAddrFromBytes(ipType int, raw []byte) netip.Addr {
	if ipType == ipTypeV4 {
		var a4 [4]byte
		copy(a4[:], raw)
		return netip.AddrFrom4(a4)
	}
	var a16 [16]byte
	copy(a16[:], raw)
	return netip.AddrFrom16(a16)
}

// getFirst returns the first address of the entry's network, decoded from First.
func (e *IPListEntry) getFirst() netip.Addr {
	return ipListEntryAddrFromBytes(e.IPType, e.First)
}

// getLast returns the last address of the entry's network, decoded from Last.
func (e *IPListEntry) getLast() netip.Addr {
	return ipListEntryAddrFromBytes(e.IPType, e.Last)
}
// checkProtocols normalizes Protocols: when the mask enables every protocol
// in ValidProtocols it is reset to 0, which also means "all protocols".
func (e *IPListEntry) checkProtocols() {
	coversAll := true
	for _, p := range ValidProtocols {
		if !e.HasProtocol(p) {
			coversAll = false
			break
		}
	}
	if coversAll {
		e.Protocols = 0
	}
}
// validate checks the entry and normalizes it in place.
//
// It verifies the list type and the mode (only the defender list accepts
// ListModeDeny), collapses a protocol mask covering every valid protocol to 0,
// turns a bare IP into a /32 or /128 network, rewrites IPv4-mapped IPv6
// networks in IPv4 notation and finally fills IPType, First and Last with the
// bounds of the (masked) network. The entry may be partially modified when an
// error is returned.
func (e *IPListEntry) validate() error {
	if err := CheckIPListType(e.Type); err != nil {
		return err
	}
	e.checkProtocols()
	if e.Type == IPListTypeDefender {
		if e.Mode < ListModeAllow || e.Mode > ListModeDeny {
			return util.NewValidationError(fmt.Sprintf("invalid list mode: %d", e.Mode))
		}
	} else if e.Mode != ListModeAllow {
		return util.NewValidationError("invalid list mode")
	}
	// the derived fields are recomputed below
	e.PrepareForRendering()
	if !strings.Contains(e.IPOrNet, "/") {
		// a bare IP becomes a single host network
		addr, parseErr := netip.ParseAddr(e.IPOrNet)
		if parseErr != nil {
			return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid IP %q", e.IPOrNet)), util.I18nErrorIPInvalid)
		}
		switch {
		case addr.Is4():
			e.IPOrNet += "/32"
		case addr.Is4In6():
			e.IPOrNet = netip.AddrFrom4(addr.As4()).String() + "/32"
		default:
			e.IPOrNet += "/128"
		}
	}
	parsed, parseErr := netip.ParsePrefix(e.IPOrNet)
	if parseErr != nil {
		return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network %q: %v", e.IPOrNet, parseErr)), util.I18nErrorNetInvalid)
	}
	network := parsed.Masked()
	base := network.Addr()
	if base.Is4In6() {
		// after masking, an IPv4-mapped address implies a prefix of at least
		// 96 bits, so the IPv4 prefix length below is never negative
		e.IPOrNet = fmt.Sprintf("%s/%d", netip.AddrFrom4(base.As4()).String(), network.Bits()-96)
	}
	// TODO: to remove when the in memory ranger switch to netip
	if _, _, cidrErr := net.ParseCIDR(e.IPOrNet); cidrErr != nil {
		return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network: %v", cidrErr)), util.I18nErrorNetInvalid)
	}
	lastAddr := util.GetLastIPForPrefix(network)
	if base.Is4() || base.Is4In6() {
		first4, last4 := base.As4(), lastAddr.As4()
		e.IPType = ipTypeV4
		e.First = first4[:]
		e.Last = last4[:]
		return nil
	}
	first16, last16 := base.As16(), lastAddr.As16()
	e.IPType = ipTypeV6
	e.First = first16[:]
	e.Last = last16[:]
	return nil
}
// getACopy returns a deep copy of the entry: First and Last are copied into
// freshly allocated slices, so the result shares no memory with e.
func (e *IPListEntry) getACopy() IPListEntry {
	clone := *e
	clone.First = make([]byte, len(e.First))
	copy(clone.First, e.First)
	clone.Last = make([]byte, len(e.Last))
	copy(clone.Last, e.Last)
	return clone
}
// getAsRangerEntry returns the entry as cidranger.RangerEntry.
// The network is parsed from IPOrNet with net.ParseCIDR (which masks the host
// bits) and the wrapped entry is a deep copy, so later changes to e do not
// affect the ranger.
func (e *IPListEntry) getAsRangerEntry() (cidranger.RangerEntry, error) {
	_, ipNet, err := net.ParseCIDR(e.IPOrNet)
	if err != nil {
		return nil, err
	}
	copied := e.getACopy()
	result := &rangerEntry{network: *ipNet}
	result.entry = &copied
	return result, nil
}
// satisfySearchConstraints reports whether the entry matches a listing query:
// IPOrNet must start with filter (when set) and, when from is set, must sort
// strictly after from for OrderASC or strictly before it otherwise.
func (e IPListEntry) satisfySearchConstraints(filter, from, order string) bool {
	if filter != "" && !strings.HasPrefix(e.IPOrNet, filter) {
		return false
	}
	switch {
	case from == "":
		return true
	case order == OrderASC:
		return e.IPOrNet > from
	default:
		return e.IPOrNet < from
	}
}
// rangerEntry pairs a private copy of an IPListEntry with its parsed network
// so it can be stored in a cidranger.Ranger.
type rangerEntry struct {
	network net.IPNet
	entry   *IPListEntry
}

// Network returns the network the wrapped entry applies to.
func (re *rangerEntry) Network() net.IPNet {
	return re.network
}
// IPList defines an IP list.
// While isInMemory is true, IsListed matches against the in-memory Ranges;
// otherwise every lookup queries the data provider.
type IPList struct {
	// set by NewIPList when the entries fit in memory; cleared on ranger
	// errors, when ipListMemoryLimit is reached or by DisableMemoryMode
	isInMemory atomic.Bool
	// list type served by this instance; addEntry, removeEntry and
	// updateEntry ignore entries of other types
	listType IPListType
	// protects Ranges for concurrent lookups and updates
	mu sync.RWMutex
	// in-memory CIDR ranger; nil when NewIPList disabled memory mode
	Ranges cidranger.Ranger
}
// addEntry inserts e into the in-memory ranger. It is a no-op when e belongs
// to another list type or memory mode is disabled. Any error, or reaching
// ipListMemoryLimit entries, disables memory mode so lookups fall back to the
// data provider.
func (l *IPList) addEntry(e *IPListEntry) {
	if l.listType != e.Type || !l.isInMemory.Load() {
		return
	}
	rEntry, err := e.getAsRangerEntry()
	if err != nil {
		providerLog(logger.LevelError, "unable to get entry to add %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, err)
		l.isInMemory.Store(false)
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	switch insertErr := l.Ranges.Insert(rEntry); {
	case insertErr != nil:
		providerLog(logger.LevelError, "unable to add entry %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, insertErr)
		l.isInMemory.Store(false)
	case l.Ranges.Len() >= ipListMemoryLimit:
		providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
		l.isInMemory.Store(false)
	}
}
// removeEntry deletes the network of e from the in-memory ranger. It is a
// no-op when e belongs to another list type or memory mode is disabled; any
// error disables memory mode.
func (l *IPList) removeEntry(e *IPListEntry) {
	if l.listType != e.Type || !l.isInMemory.Load() {
		return
	}
	rEntry, err := e.getAsRangerEntry()
	if err != nil {
		providerLog(logger.LevelError, "unable to get entry to remove %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, err)
		l.isInMemory.Store(false)
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	_, err = l.Ranges.Remove(rEntry.Network())
	if err == nil {
		return
	}
	providerLog(logger.LevelError, "unable to remove entry %q for list type %d, disabling memory mode, err: %v",
		e.IPOrNet, l.listType, err)
	l.isInMemory.Store(false)
}
// updateEntry replaces the in-memory ranger entry for e (remove, then insert).
// It is a no-op when e belongs to another list type or memory mode is
// disabled. Any error, or reaching ipListMemoryLimit entries, disables memory
// mode so lookups fall back to the data provider.
func (l *IPList) updateEntry(e *IPListEntry) {
	if l.listType != e.Type || !l.isInMemory.Load() {
		return
	}
	rEntry, err := e.getAsRangerEntry()
	if err != nil {
		providerLog(logger.LevelError, "unable to get entry to update %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, err)
		l.isInMemory.Store(false)
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	if _, err := l.Ranges.Remove(rEntry.Network()); err != nil {
		providerLog(logger.LevelError, "unable to remove entry to update %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, err)
		l.isInMemory.Store(false)
		return
	}
	// A failed insert already disables memory mode, so the size check is
	// skipped in that case (matching addEntry) to avoid a spurious
	// "memory limit exceeded" log.
	switch insertErr := l.Ranges.Insert(rEntry); {
	case insertErr != nil:
		providerLog(logger.LevelError, "unable to add entry to update %q for list type %d, disabling memory mode, err: %v",
			e.IPOrNet, l.listType, insertErr)
		l.isInMemory.Store(false)
	case l.Ranges.Len() >= ipListMemoryLimit:
		providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType)
		l.isInMemory.Store(false)
	}
}
// DisableMemoryMode disables memory mode forcing database queries: from then
// on IsListed queries the data provider, and addEntry, removeEntry and
// updateEntry stop maintaining the in-memory ranger.
func (l *IPList) DisableMemoryMode() {
	// the swap only happens when enabled; an already disabled list stays disabled
	l.isInMemory.CompareAndSwap(true, false)
}
// IsListed checks if there is a match for the specified IP and protocol.
// An entry matches when its network contains ip and it applies either to all
// protocols (Protocols == 0) or to protocol. It returns whether a match was
// found and, if so, the matching entry's mode.
// If there are multiple matches, the first one is returned, in no particular order,
// so the behavior is undefined
func (l *IPList) IsListed(ip, protocol string) (bool, int, error) {
	if !l.isInMemory.Load() {
		dbEntries, err := provider.getListEntriesForIP(ip, l.listType)
		if err != nil {
			return false, 0, err
		}
		for _, candidate := range dbEntries {
			if candidate.Protocols == 0 || candidate.HasProtocol(protocol) {
				return true, candidate.Mode, nil
			}
		}
		return false, 0, nil
	}
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return false, 0, fmt.Errorf("invalid IP %s", ip)
	}
	l.mu.RLock()
	defer l.mu.RUnlock()
	networks, err := l.Ranges.ContainingNetworks(parsedIP)
	if err != nil {
		return false, 0, fmt.Errorf("unable to find containing networks for ip %q: %w", ip, err)
	}
	for _, n := range networks {
		match, ok := n.(*rangerEntry)
		if !ok {
			continue
		}
		if match.entry.Protocols == 0 || match.entry.HasProtocol(protocol) {
			return true, match.entry.Mode, nil
		}
	}
	return false, 0, nil
}
// NewIPList returns a new IP list for the specified type.
//
// Any in-memory list previously registered for listType is dropped first.
// When the provider holds fewer than ipListMemoryLimit entries for the type,
// all of them are loaded into an in-memory ranger and the list is registered
// in inMemoryLists; otherwise the returned list has memory mode disabled and
// every lookup is delegated to the data provider.
func NewIPList(listType IPListType) (*IPList, error) {
	delete(inMemoryLists, listType)
	count, err := provider.countIPListEntries(listType)
	if err != nil {
		return nil, err
	}
	if count >= ipListMemoryLimit {
		providerLog(logger.LevelInfo, "list type %d has %d entries, in-memory matching disabled", listType, count)
		// the zero value of isInMemory is false and Ranges stays nil
		return &IPList{listType: listType}, nil
	}
	providerLog(logger.LevelInfo, "using in-memory matching for list type %d, num entries: %d", listType, count)
	entries, err := provider.getIPListEntries(listType, "", "", OrderASC, 0)
	if err != nil {
		return nil, err
	}
	ipList := &IPList{
		listType: listType,
		Ranges:   cidranger.NewPCTrieRanger(),
	}
	for _, item := range entries {
		rEntry, err := item.getAsRangerEntry()
		if err != nil {
			return nil, fmt.Errorf("unable to get ranger for entry %q: %w", item.IPOrNet, err)
		}
		if err := ipList.Ranges.Insert(rEntry); err != nil {
			return nil, fmt.Errorf("unable to add ranger for entry %q: %w", item.IPOrNet, err)
		}
	}
	ipList.isInMemory.Store(true)
	inMemoryLists[listType] = ipList
	return ipList, nil
}