sftpgo/common/transferschecker.go
Nicola Murino 1df1225eed
add support for data transfer bandwidth limits
with total limit or separate settings for uploads and downloads and
overrides based on the client's IP address.

Limits can be reset using the REST API

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-01-30 11:42:36 +01:00

315 lines
9 KiB
Go

package common
import (
"errors"
"sync"
"time"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
// overquotaTransfer identifies an active transfer that was found to exceed
// the allowed quota. It carries the connection ID, the transfer ID and the
// transfer type copied from the matching dataprovider.ActiveTransfer.
type overquotaTransfer struct {
// ConnID is the ID of the connection that owns the transfer
ConnID string
// TransferID is the ID of the transfer
TransferID int64
// TransferType is the transfer type, copied from ActiveTransfer.Type
TransferType int
}
// uploadAggregationKey is the pair used to group concurrent uploads:
// uploads with the same username and the same folder name belong to the
// same aggregation.
type uploadAggregationKey struct {
// Username is the name of the user performing the upload
Username string
// FolderName is the folder name of the upload; an empty name is treated by
// getRemainingDiskQuota as "use the user level quota"
FolderName string
}
// TransfersChecker defines the interface that transfer checkers must implement.
// A transfer checker ensures that multiple concurrent transfers do not exceed
// the remaining user quota.
type TransfersChecker interface {
// AddTransfer registers a new active transfer with the checker
AddTransfer(transfer dataprovider.ActiveTransfer)
// RemoveTransfer unregisters the transfer identified by the given transfer
// ID and connection ID
RemoveTransfer(ID int64, connectionID string)
// UpdateTransferCurrentSizes stores the current uploaded and downloaded
// sizes for the transfer identified by the given transfer ID and connection ID
UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string)
// GetOverquotaTransfers returns the active transfers that exceed the
// allowed quota
GetOverquotaTransfers() []overquotaTransfer
}
// getTransfersChecker returns the transfers checker to use: the data provider
// backed implementation if isShared is 1, the in-memory one for any other value.
// The selected implementation is logged.
func getTransfersChecker(isShared int) TransfersChecker {
	switch isShared {
	case 1:
		logger.Info(logSender, "", "using provider transfer checker")
		return &transfersCheckerDB{}
	default:
		logger.Info(logSender, "", "using memory transfer checker")
		return &transfersCheckerMem{}
	}
}
// baseTransferChecker holds the list of active transfers and implements the
// aggregation and quota check logic shared by the concrete checkers that
// embed it.
type baseTransferChecker struct {
// transfers is the list of active transfers to check. This type does no
// locking of its own: the embedding type is responsible for synchronization
transfers []dataprovider.ActiveTransfer
}
// isDataTransferExceeded reports whether the given transfer must be considered
// over the data transfer quota of the given user.
//
// ulSize and dlSize are the upload and download sizes to compare against the
// limits; the caller in this file passes the sums over all the ongoing
// transfers of the user. The limits are the ones returned by
// user.GetDataTransferLimits for the transfer IP and are evaluated in this
// order: total, download, upload. A limit that is not positive is skipped.
// The first exceeded limit decides the result: the transfer is reported only
// if it has already moved bytes in the direction(s) covered by that limit.
func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize,
	dlSize int64,
) bool {
	ulLimit, dlLimit, totalLimit := user.GetDataTransferLimits(transfer.IP)

	switch {
	case totalLimit > 0 && ulSize+dlSize > totalLimit-(user.UsedUploadDataTransfer+user.UsedDownloadDataTransfer):
		// total limit exceeded: any direction counts
		return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0
	case dlLimit > 0 && dlSize > dlLimit-user.UsedDownloadDataTransfer:
		// download limit exceeded: only downloading transfers are affected
		return transfer.CurrentDLSize > 0
	case ulLimit > 0 && ulSize > ulLimit-user.UsedUploadDataTransfer:
		// upload limit exceeded: only uploading transfers are affected
		return transfer.CurrentULSize > 0
	}
	return false
}
// getRemainingDiskQuota returns the remaining disk quota, as quota size minus
// used quota size, for the given user or, if folderName is not empty, for the
// user's virtual folder with that name.
//
// An error is returned if no positive quota size is defined: this happens if
// the user level quota size is not positive or if no virtual folder with the
// given name and a positive quota size is found.
func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
	if folderName == "" {
		if user.QuotaSize > 0 {
			return user.QuotaSize - user.UsedQuotaSize, nil
		}
		return 0, errors.New("no quota limit defined")
	}

	for _, vfolder := range user.VirtualFolders {
		if vfolder.Name != folderName || vfolder.QuotaSize <= 0 {
			continue
		}
		return vfolder.QuotaSize - vfolder.UsedQuotaSize, nil
	}
	return 0, errors.New("no quota limit defined")
}
// aggregateTransfersByUser groups the active transfers by username.
//
// Users with more than one active transfer are added to usersToFetch with a
// false value unless they are already present, so existing entries are never
// overwritten. usersToFetch is modified in place, it must not be nil, and it
// is also returned together with the per-user aggregations.
func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool,
) (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
	perUser := make(map[string][]dataprovider.ActiveTransfer)

	for idx := range t.transfers {
		name := t.transfers[idx].Username
		group := append(perUser[name], t.transfers[idx])
		perUser[name] = group
		if len(group) < 2 {
			continue
		}
		if _, found := usersToFetch[name]; found {
			continue
		}
		usersToFetch[name] = false
	}

	return usersToFetch, perUser
}
// aggregateUploadTransfers groups the active upload transfers by username and
// folder name. Transfers that are not uploads are ignored.
//
// It returns:
//   - the users to fetch: a user is included only if at least one of its
//     aggregations contains more than one transfer. The value is true if one of
//     these aggregations has a non-empty folder name, false otherwise
//     (presumably it tells the provider whether the virtual folders are needed,
//     verify against dataprovider.GetUsersForQuotaCheck)
//   - the aggregations, indexed by a progressive number assigned the first time
//     a username/folder name pair is seen. The index has no meaning other than
//     identifying the aggregation
func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) {
	usersToFetch := make(map[string]bool)
	aggregations := make(map[int][]dataprovider.ActiveTransfer)
	// maps each username/folder name pair to its aggregation index. A pair is
	// recorded only once, so the lookup is O(1) and memory usage is bounded by
	// the number of distinct pairs instead of the number of uploads
	indexes := make(map[uploadAggregationKey]int)

	for _, transfer := range t.transfers {
		if transfer.Type != TransferUpload {
			continue
		}
		aggregationKey := uploadAggregationKey{
			Username:   transfer.Username,
			FolderName: transfer.FolderName,
		}
		key, ok := indexes[aggregationKey]
		if !ok {
			key = len(indexes)
			indexes[aggregationKey] = key
		}
		aggregations[key] = append(aggregations[key], transfer)
		if len(aggregations[key]) > 1 {
			if transfer.FolderName != "" {
				usersToFetch[transfer.Username] = true
			} else if _, ok := usersToFetch[transfer.Username]; !ok {
				// don't downgrade a user already marked as true
				usersToFetch[transfer.Username] = false
			}
		}
	}

	return usersToFetch, aggregations
}
// getUsersToCheck loads the given users using
// dataprovider.GetUsersForQuotaCheck and returns them indexed by username.
// If the provider returns an error, it is returned as is with a nil map.
func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) {
	fetched, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
	if err != nil {
		return nil, err
	}

	byUsername := make(map[string]dataprovider.User, len(fetched))
	for _, u := range fetched {
		byUsername[u.Username] = u
	}
	return byUsername, nil
}
// getOverquotaTransfers returns the transfers that exceed the disk quota or
// the data transfer quota.
//
// usersToFetch lists the users to load from the data provider: if it is empty
// nothing is checked and nil is returned. nil is also returned, after logging
// a warning, if the users cannot be loaded.
//
// Disk quota: for each upload aggregation the bytes used by the ongoing
// uploads (current upload size minus truncated size) are summed and compared
// with the remaining disk quota of the user or folder. If the sum is greater,
// every transfer of the aggregation that uploaded more bytes than it truncated
// is reported. Aggregations with no quota limit defined are skipped, including
// those whose user was not loaded, because a zero value user has no limit.
// Each aggregation is expected to contain at least one transfer.
//
// Data transfer quota: for each user the upload and download sizes of the
// ongoing transfers are summed and each transfer is checked using
// isDataTransferExceeded.
//
// A transfer can be reported twice if it fails both checks.
func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool,
	uploadAggregations map[int][]dataprovider.ActiveTransfer,
	userAggregations map[string][]dataprovider.ActiveTransfer,
) []overquotaTransfer {
	if len(usersToFetch) == 0 {
		return nil
	}
	users, err := t.getUsersToCheck(usersToFetch)
	if err != nil {
		logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
		return nil
	}

	var result []overquotaTransfer

	// disk quota check
	for _, group := range uploadAggregations {
		owner, folder := group[0].Username, group[0].FolderName
		available, quotaErr := t.getRemainingDiskQuota(users[owner], folder)
		if quotaErr != nil {
			continue
		}
		var inFlight int64
		for idx := range group {
			// We optimistically assume that a cloud transfer that replaces an existing
			// file will be successful
			inFlight += group[idx].CurrentULSize - group[idx].TruncatedSize
		}
		logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v",
			owner, folder, len(group), available, inFlight)
		if inFlight <= available {
			continue
		}
		for idx := range group {
			if group[idx].CurrentULSize <= group[idx].TruncatedSize {
				continue
			}
			result = append(result, overquotaTransfer{
				ConnID:       group[idx].ConnID,
				TransferID:   group[idx].ID,
				TransferType: group[idx].Type,
			})
		}
	}

	// data transfer quota check
	for owner, group := range userAggregations {
		var totalUL, totalDL int64
		for idx := range group {
			totalUL += group[idx].CurrentULSize
			totalDL += group[idx].CurrentDLSize
		}
		logger.Debug(logSender, "", "username %#v, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v",
			owner, len(group), totalUL, totalDL)
		for idx := range group {
			if !t.isDataTransferExceeded(users[owner], group[idx], totalUL, totalDL) {
				continue
			}
			result = append(result, overquotaTransfer{
				ConnID:       group[idx].ConnID,
				TransferID:   group[idx].ID,
				TransferType: group[idx].Type,
			})
		}
	}

	return result
}
// transfersCheckerMem is a TransfersChecker that keeps the active transfers
// in memory. The embedded RWMutex protects the transfers list of the embedded
// baseTransferChecker.
type transfersCheckerMem struct {
sync.RWMutex
baseTransferChecker
}
// AddTransfer appends the given transfer to the in-memory list of active
// transfers. The write lock is held while the list is modified.
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
	t.Lock()
	t.transfers = append(t.transfers, transfer)
	t.Unlock()
}
// RemoveTransfer removes the first transfer matching the given transfer ID
// and connection ID from the in-memory list. It does nothing if there is no
// match.
//
// The matching element is replaced with the last one before shrinking the
// list, so the order of the remaining transfers is not preserved.
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	last := len(t.transfers) - 1
	for pos := 0; pos <= last; pos++ {
		if t.transfers[pos].ID != ID || t.transfers[pos].ConnID != connectionID {
			continue
		}
		t.transfers[pos] = t.transfers[last]
		t.transfers = t.transfers[:last]
		return
	}
}
// UpdateTransferCurrentSizes sets the current upload and download sizes for
// the first transfer matching the given transfer ID and connection ID and
// refreshes its UpdatedAt field with the current time, as returned by
// util.GetTimeAsMsSinceEpoch. It does nothing if there is no match.
func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	for pos := range t.transfers {
		// update in place, we need a pointer to the slice element
		current := &t.transfers[pos]
		if current.ID != ID || current.ConnID != connectionID {
			continue
		}
		current.CurrentDLSize = dlSize
		current.CurrentULSize = ulSize
		current.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		return
	}
}
// GetOverquotaTransfers returns the in-memory transfers that exceed the
// allowed quota.
//
// The read lock is held only while the transfers are aggregated. The
// aggregations are then checked, and the users loaded, after releasing the
// lock.
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
	t.RLock()
	toFetch, byUpload := t.aggregateUploadTransfers()
	toFetch, byUser := t.aggregateTransfersByUser(toFetch)
	t.RUnlock()

	return t.getOverquotaTransfers(toFetch, byUpload, byUser)
}
// transfersCheckerDB is a TransfersChecker that stores the active transfers
// using the data provider. The transfers list of the embedded
// baseTransferChecker is reloaded from the provider on each check.
type transfersCheckerDB struct {
baseTransferChecker
// lastCleanup is the time of the last successful cleanup of the active
// transfers stored in the provider, the zero value means never
lastCleanup time.Time
}
// AddTransfer registers the given transfer by delegating to
// dataprovider.AddActiveTransfer. No result is reported to the caller.
func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) {
dataprovider.AddActiveTransfer(transfer)
}
// RemoveTransfer unregisters the transfer with the given transfer ID and
// connection ID by delegating to dataprovider.RemoveActiveTransfer.
// No result is reported to the caller.
func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) {
dataprovider.RemoveActiveTransfer(ID, connectionID)
}
// UpdateTransferCurrentSizes stores the current upload and download sizes for
// the transfer with the given transfer ID and connection ID by delegating to
// dataprovider.UpdateActiveTransferSizes. No result is reported to the caller.
func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID)
}
// GetOverquotaTransfers returns the transfers stored in the data provider
// that exceed the allowed quota.
//
// Before checking, a cleanup of the stored transfers is requested if it never
// completed successfully or if more than 15 check intervals elapsed since the
// last successful one. The cleanup cutoff is 5 check intervals ago.
//
// The transfers to check are then loaded from the provider using a cutoff of
// 2 check intervals ago (presumably transfers not updated since the cutoff
// are considered stale, verify against dataprovider.GetActiveTransfers).
// nil is returned, after logging an error, if they cannot be loaded.
//
// This method updates lastCleanup and the transfers list without any locking.
func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer {
	cleanupDue := t.lastCleanup.IsZero() ||
		time.Now().After(t.lastCleanup.Add(periodicTimeoutCheckInterval*15))
	if cleanupDue {
		cutoff := time.Now().Add(-periodicTimeoutCheckInterval * 5)
		cleanupErr := dataprovider.CleanupActiveTransfers(cutoff)
		logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", cleanupErr)
		if cleanupErr == nil {
			t.lastCleanup = time.Now()
		}
	}

	activeSince := time.Now().Add(-periodicTimeoutCheckInterval * 2)
	active, err := dataprovider.GetActiveTransfers(activeSince)
	// the list is replaced even on error, as callers of the other methods
	// never read it for this implementation
	t.transfers = active
	if err != nil {
		logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err)
		return nil
	}

	toFetch, byUpload := t.aggregateUploadTransfers()
	toFetch, byUser := t.aggregateTransfersByUser(toFetch)
	return t.getOverquotaTransfers(toFetch, byUpload, byUser)
}