sftpgo/common/transferschecker.go

316 lines
9 KiB
Go
Raw Normal View History

package common
import (
"errors"
"sync"
"time"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
// overquotaTransfer identifies an active transfer that was found to exceed a
// quota limit. It carries the ID of the connection the transfer belongs to,
// the transfer ID and the transfer type, copied from the matching
// dataprovider.ActiveTransfer.
type overquotaTransfer struct {
ConnID string
TransferID int64
TransferType int
}
// uploadAggregationKey is the grouping key used when aggregating upload
// transfers: uploads with the same username and the same folder name belong
// to the same aggregation. An empty FolderName is treated by the quota checks
// as "no virtual folder", meaning the user level quota applies.
type uploadAggregationKey struct {
Username string
FolderName string
}
// TransfersChecker defines the interface that transfer checkers must implement.
// A transfer checker ensures that multiple concurrent transfers do not exceed
// the remaining user quota.
type TransfersChecker interface {
// AddTransfer registers a new active transfer so that it is taken into
// account by the next quota check.
AddTransfer(transfer dataprovider.ActiveTransfer)
// RemoveTransfer unregisters the transfer identified by the given
// transfer ID and connection ID.
RemoveTransfer(ID int64, connectionID string)
// UpdateTransferCurrentSizes records the current uploaded and downloaded
// sizes for the transfer identified by the given transfer ID and
// connection ID.
UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string)
// GetOverquotaTransfers returns the registered transfers that, evaluated
// together with the other concurrent transfers, exceed a quota limit.
GetOverquotaTransfers() []overquotaTransfer
}
// getTransfersChecker returns the TransfersChecker implementation to use.
// If isShared is 1 the data provider backed checker is returned, any other
// value selects the in-memory checker. The selected implementation is logged
// at info level.
func getTransfersChecker(isShared int) TransfersChecker {
	switch isShared {
	case 1:
		logger.Info(logSender, "", "using provider transfer checker")
		return &transfersCheckerDB{}
	default:
		logger.Info(logSender, "", "using memory transfer checker")
		return &transfersCheckerMem{}
	}
}
// baseTransferChecker holds the list of active transfers to evaluate and
// implements the aggregation and quota checking logic shared by the concrete
// checkers, which embed it. It does no locking of its own: access to the
// transfers slice must be synchronized by the embedding type if needed.
type baseTransferChecker struct {
transfers []dataprovider.ActiveTransfer
}
// isDataTransferExceeded reports whether the given transfer must be considered
// over quota with respect to the data transfer limits of its user.
//
// ulSize and dlSize are the uploaded and downloaded bytes summed over all the
// ongoing transfers of the user. The limits are read from
// user.GetDataTransferLimits for the transfer IP, a limit that is not greater
// than zero is not enforced. The bytes already accounted on the user
// (UsedUploadDataTransfer, UsedDownloadDataTransfer) are subtracted from each
// limit to get the size still allowed.
//
// The total, download and upload limits are evaluated independently. A
// transfer is reported as exceeded as soon as one limit is exceeded by the
// aggregated sizes and the transfer itself contributed to it, that is it has
// moved at least one byte in the relevant direction. A limit exceeded only by
// other transfers does not stop the evaluation of the remaining limits:
// returning early there would, for example, never flag an upload when the
// download limit is exceeded too.
func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize,
	dlSize int64,
) bool {
	ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits(transfer.IP)
	if totalQuota > 0 {
		allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer)
		if ulSize+dlSize > allowedSize && (transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0) {
			return true
		}
	}
	if dlQuota > 0 {
		allowedSize := dlQuota - user.UsedDownloadDataTransfer
		if dlSize > allowedSize && transfer.CurrentDLSize > 0 {
			return true
		}
	}
	if ulQuota > 0 {
		allowedSize := ulQuota - user.UsedUploadDataTransfer
		if ulSize > allowedSize && transfer.CurrentULSize > 0 {
			return true
		}
	}
	return false
}
// getRemainingDiskQuota returns the disk quota, in bytes, still available for
// the given user or, if folderName is not empty, for the user's virtual folder
// with that name. The value is the configured quota size minus the used quota
// size and so it can be negative.
//
// An error is returned if no size limit applies: the user quota size is not
// greater than zero or, for a folder, no virtual folder with the given name
// and a quota size greater than zero exists.
func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
	var (
		remaining int64
		limited   bool
	)
	if folderName == "" {
		remaining, limited = user.QuotaSize-user.UsedQuotaSize, user.QuotaSize > 0
	} else {
		for _, vf := range user.VirtualFolders {
			if vf.Name == folderName && vf.QuotaSize > 0 {
				remaining, limited = vf.QuotaSize-vf.UsedQuotaSize, true
				break
			}
		}
	}
	if !limited {
		return 0, errors.New("no quota limit defined")
	}
	return remaining, nil
}
// aggregateTransfersByUser groups all the tracked transfers, uploads and
// downloads, by username.
//
// usersToFetch is updated in place and returned: every user with more than one
// transfer is added with a false value, unless already present, in which case
// the existing value is kept. usersToFetch must not be nil.
func (t *baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool,
) (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
	byUser := make(map[string][]dataprovider.ActiveTransfer)
	for idx := range t.transfers {
		name := t.transfers[idx].Username
		group := byUser[name]
		group = append(group, t.transfers[idx])
		byUser[name] = group
		if len(group) < 2 {
			// a single transfer, nothing to check for this user so far
			continue
		}
		if _, known := usersToFetch[name]; known {
			continue
		}
		usersToFetch[name] = false
	}
	return usersToFetch, byUser
}
// aggregateUploadTransfers groups the tracked upload transfers by username and
// folder name, transfers of any other type are ignored.
//
// It returns the users to fetch and the aggregations. The aggregations map is
// keyed by an opaque index assigned in order of first appearance of each
// (username, folder name) pair, starting from 0. A user is added to the users
// to fetch as soon as one of its aggregations contains more than one upload:
// the value is set to true if the aggregation refers to a folder, otherwise it
// is set to false unless the user is already present.
// NOTE(review): the boolean presumably tells
// dataprovider.GetUsersForQuotaCheck whether the virtual folders are needed,
// confirm against the provider.
func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) {
	usersToFetch := make(map[string]bool)
	aggregations := make(map[int][]dataprovider.ActiveTransfer)
	// indexes maps each (username, folder name) pair to its aggregation key.
	// A map lookup avoids scanning, and growing, a list of keys for every
	// transfer.
	indexes := make(map[uploadAggregationKey]int)

	for _, transfer := range t.transfers {
		if transfer.Type != TransferUpload {
			continue
		}
		aggregationKey := uploadAggregationKey{
			Username:   transfer.Username,
			FolderName: transfer.FolderName,
		}
		key, ok := indexes[aggregationKey]
		if !ok {
			key = len(indexes)
			indexes[aggregationKey] = key
		}
		aggregations[key] = append(aggregations[key], transfer)
		if len(aggregations[key]) > 1 {
			if transfer.FolderName != "" {
				usersToFetch[transfer.Username] = true
			} else if _, ok := usersToFetch[transfer.Username]; !ok {
				// do not downgrade a value previously set to true
				usersToFetch[transfer.Username] = false
			}
		}
	}
	return usersToFetch, aggregations
}
// getUsersToCheck loads the given users using
// dataprovider.GetUsersForQuotaCheck and returns them indexed by username.
// If the provider returns an error, a nil map and that error are returned.
func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) {
	fetched, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
	if err != nil {
		return nil, err
	}
	byName := make(map[string]dataprovider.User)
	for _, u := range fetched {
		byName[u.Username] = u
	}
	return byName, nil
}
// getOverquotaTransfers evaluates the given aggregations and returns the
// transfers that exceed a quota limit.
//
// Nothing is checked, and nil is returned, if usersToFetch is empty or if the
// users cannot be loaded; in the latter case a warning is logged.
//
// Two checks are executed, in this order:
//   - disk quota: for each upload aggregation the bytes written by its
//     transfers are compared with the remaining disk quota of the user or
//     folder. Aggregations without a quota limit are skipped.
//   - data transfer: each transfer of each user aggregation is checked using
//     isDataTransferExceeded and the sizes summed over the whole aggregation.
//
// Users missing from the loaded ones are evaluated as a zero value
// dataprovider.User. Each aggregation in uploadAggregations must contain at
// least one transfer.
func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool,
	uploadAggregations map[int][]dataprovider.ActiveTransfer,
	userAggregations map[string][]dataprovider.ActiveTransfer,
) []overquotaTransfer {
	if len(usersToFetch) == 0 {
		return nil
	}
	users, err := t.getUsersToCheck(usersToFetch)
	if err != nil {
		logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
		return nil
	}

	var result []overquotaTransfer
	markOverquota := func(tr *dataprovider.ActiveTransfer) {
		result = append(result, overquotaTransfer{
			ConnID:       tr.ConnID,
			TransferID:   tr.ID,
			TransferType: tr.Type,
		})
	}

	// disk quota check
	for _, uploads := range uploadAggregations {
		owner, folder := uploads[0].Username, uploads[0].FolderName
		available, err := t.getRemainingDiskQuota(users[owner], folder)
		if err != nil {
			// no disk quota limit for this user/folder
			continue
		}
		var inFlight int64
		for idx := range uploads {
			// The truncated size is subtracted: a cloud transfer replacing an
			// existing file is optimistically assumed to be successful
			inFlight += uploads[idx].CurrentULSize - uploads[idx].TruncatedSize
		}
		logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v",
			owner, folder, len(uploads), available, inFlight)
		if inFlight <= available {
			continue
		}
		for idx := range uploads {
			// only the transfers that are actually consuming quota are reported
			if uploads[idx].CurrentULSize > uploads[idx].TruncatedSize {
				markOverquota(&uploads[idx])
			}
		}
	}

	// data transfer check
	for username, active := range userAggregations {
		var ulSize, dlSize int64
		for idx := range active {
			ulSize += active[idx].CurrentULSize
			dlSize += active[idx].CurrentDLSize
		}
		logger.Debug(logSender, "", "username %#v, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v",
			username, len(active), ulSize, dlSize)
		owner := users[username]
		for idx := range active {
			if !t.isDataTransferExceeded(owner, active[idx], ulSize, dlSize) {
				continue
			}
			markOverquota(&active[idx])
		}
	}
	return result
}
// transfersCheckerMem is a TransfersChecker that keeps the active transfers in
// memory. The embedded RWMutex guards the transfers slice of the embedded
// baseTransferChecker.
type transfersCheckerMem struct {
sync.RWMutex
baseTransferChecker
}
// AddTransfer appends the given transfer to the in-memory list of active
// transfers while holding the write lock.
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
	t.Lock()
	t.transfers = append(t.transfers, transfer)
	t.Unlock()
}
// RemoveTransfer removes the first tracked transfer matching both the given
// transfer ID and connection ID. It does nothing if no transfer matches.
// The order of the remaining transfers is not preserved: the removed entry is
// replaced by the last one and the list is shortened by one element.
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	for pos := range t.transfers {
		if t.transfers[pos].ID != ID || t.transfers[pos].ConnID != connectionID {
			continue
		}
		tail := len(t.transfers) - 1
		t.transfers[pos] = t.transfers[tail]
		t.transfers = t.transfers[:tail]
		return
	}
}
// UpdateTransferCurrentSizes sets the current upload and download sizes on the
// first tracked transfer matching both the given transfer ID and connection ID
// and refreshes its UpdatedAt field with the current time, in milliseconds
// since epoch. It does nothing if no transfer matches.
func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
	t.Lock()
	defer t.Unlock()

	for pos := range t.transfers {
		tr := &t.transfers[pos]
		if tr.ID != ID || tr.ConnID != connectionID {
			continue
		}
		tr.CurrentDLSize = dlSize
		tr.CurrentULSize = ulSize
		tr.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		return
	}
}
// GetOverquotaTransfers returns the in-memory transfers that exceed a quota
// limit. The read lock is held only while the transfers are aggregated: the
// aggregations contain copies of the transfers, so the quota evaluation, which
// loads the users, runs after the lock is released.
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
	t.RLock()
	toFetch, uploads := t.aggregateUploadTransfers()
	toFetch, perUser := t.aggregateTransfersByUser(toFetch)
	t.RUnlock()

	return t.getOverquotaTransfers(toFetch, uploads, perUser)
}
// transfersCheckerDB is a TransfersChecker that delegates the storage of the
// active transfers to the data provider. lastCleanup is the time of the last
// successful cleanup of the stored transfers, the zero value means that no
// cleanup has completed yet. The struct has no lock of its own.
type transfersCheckerDB struct {
baseTransferChecker
lastCleanup time.Time
}
// AddTransfer stores the given transfer by calling
// dataprovider.AddActiveTransfer. No result of that call is inspected here.
func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) {
dataprovider.AddActiveTransfer(transfer)
}
// RemoveTransfer removes the transfer identified by the given transfer ID and
// connection ID by calling dataprovider.RemoveActiveTransfer. No result of
// that call is inspected here.
func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) {
dataprovider.RemoveActiveTransfer(ID, connectionID)
}
// UpdateTransferCurrentSizes updates the current upload and download sizes of
// the transfer identified by the given transfer ID and connection ID by
// calling dataprovider.UpdateActiveTransferSizes. No result of that call is
// inspected here.
func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) {
dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID)
}
// GetOverquotaTransfers returns the transfers stored in the data provider that
// exceed a quota limit.
//
// Before the check, if no cleanup has completed yet or the last one is older
// than 15 times periodicTimeoutCheckInterval, the provider is asked to clean
// up the active transfers using a threshold of 5 intervals in the past. The
// cleanup outcome is logged and a failure does not prevent the check: the
// cleanup is simply attempted again on the next call.
//
// The active transfers are then loaded from the provider, passing a time 2
// intervals in the past, and stored in t.transfers. If loading fails the error
// is logged and nil is returned.
func (t *transfersCheckerDB) GetOverquotaTransfers() []overquotaTransfer {
	cleanupDue := t.lastCleanup.IsZero() ||
		time.Now().After(t.lastCleanup.Add(periodicTimeoutCheckInterval*15))
	if cleanupDue {
		threshold := time.Now().Add(-periodicTimeoutCheckInterval * 5)
		cleanupErr := dataprovider.CleanupActiveTransfers(threshold)
		logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", cleanupErr)
		if cleanupErr == nil {
			t.lastCleanup = time.Now()
		}
	}

	var err error
	t.transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2))
	if err != nil {
		logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err)
		return nil
	}

	toFetch, uploads := t.aggregateUploadTransfers()
	toFetch, perUser := t.aggregateTransfersByUser(toFetch)
	return t.getOverquotaTransfers(toFetch, uploads, perUser)
}