309 lines
7.4 KiB
Go
309 lines
7.4 KiB
Go
package dataprovider
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/drakkan/sftpgo/logger"
|
|
"github.com/drakkan/sftpgo/utils"
|
|
)
|
|
|
|
var (
|
|
errMemoryProviderClosed = errors.New("memory provider is closed")
|
|
)
|
|
|
|
type memoryProviderHandle struct {
|
|
isClosed bool
|
|
// slice with ordered usernames
|
|
usernames []string
|
|
// mapping between ID and username
|
|
usersIdx map[int64]string
|
|
// map for users, username is the key
|
|
users map[string]User
|
|
lock *sync.Mutex
|
|
}
|
|
|
|
// MemoryProvider auth provider for a memory store
|
|
type MemoryProvider struct {
|
|
dbHandle *memoryProviderHandle
|
|
}
|
|
|
|
func initializeMemoryProvider() error {
|
|
provider = MemoryProvider{
|
|
dbHandle: &memoryProviderHandle{
|
|
isClosed: false,
|
|
usernames: []string{},
|
|
usersIdx: make(map[int64]string),
|
|
users: make(map[string]User),
|
|
lock: new(sync.Mutex),
|
|
},
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) checkAvailability() error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) close() error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
p.dbHandle.isClosed = true
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) validateUserAndPass(username string, password string) (User, error) {
|
|
var user User
|
|
if len(password) == 0 {
|
|
return user, errors.New("Credentials cannot be null or empty")
|
|
}
|
|
user, err := p.userExists(username)
|
|
if err != nil {
|
|
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
|
|
return user, err
|
|
}
|
|
return checkUserAndPass(user, password)
|
|
}
|
|
|
|
func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
|
|
var user User
|
|
if len(pubKey) == 0 {
|
|
return user, "", errors.New("Credentials cannot be null or empty")
|
|
}
|
|
user, err := p.userExists(username)
|
|
if err != nil {
|
|
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
|
|
return user, "", err
|
|
}
|
|
return checkUserAndPubKey(user, pubKey)
|
|
}
|
|
|
|
func (p MemoryProvider) getUserByID(ID int64) (User, error) {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return User{}, errMemoryProviderClosed
|
|
}
|
|
if val, ok := p.dbHandle.usersIdx[ID]; ok {
|
|
return p.userExistsInternal(val)
|
|
}
|
|
return User{}, &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)}
|
|
}
|
|
|
|
func (p MemoryProvider) updateLastLogin(username string) error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
user, err := p.userExistsInternal(username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now())
|
|
p.dbHandle.users[user.Username] = user
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
user, err := p.userExistsInternal(username)
|
|
if err != nil {
|
|
providerLog(logger.LevelWarn, "unable to update quota for user %v error: %v", username, err)
|
|
return err
|
|
}
|
|
if reset {
|
|
user.UsedQuotaSize = sizeAdd
|
|
user.UsedQuotaFiles = filesAdd
|
|
} else {
|
|
user.UsedQuotaSize += sizeAdd
|
|
user.UsedQuotaFiles += filesAdd
|
|
}
|
|
user.LastQuotaUpdate = utils.GetTimeAsMsSinceEpoch(time.Now())
|
|
p.dbHandle.users[user.Username] = user
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) getUsedQuota(username string) (int, int64, error) {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return 0, 0, errMemoryProviderClosed
|
|
}
|
|
user, err := p.userExistsInternal(username)
|
|
if err != nil {
|
|
providerLog(logger.LevelWarn, "unable to get quota for user %v error: %v", username, err)
|
|
return 0, 0, err
|
|
}
|
|
return user.UsedQuotaFiles, user.UsedQuotaSize, err
|
|
}
|
|
|
|
func (p MemoryProvider) addUser(user User) error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
err := validateUser(&user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = p.userExistsInternal(user.Username)
|
|
if err == nil {
|
|
return fmt.Errorf("username %v already exists", user.Username)
|
|
}
|
|
user.ID = p.getNextID()
|
|
p.dbHandle.users[user.Username] = user
|
|
p.dbHandle.usersIdx[user.ID] = user.Username
|
|
p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username)
|
|
sort.Strings(p.dbHandle.usernames)
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) updateUser(user User) error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
err := validateUser(&user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = p.userExistsInternal(user.Username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.dbHandle.users[user.Username] = user
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) deleteUser(user User) error {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return errMemoryProviderClosed
|
|
}
|
|
_, err := p.userExistsInternal(user.Username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
delete(p.dbHandle.users, user.Username)
|
|
delete(p.dbHandle.usersIdx, user.ID)
|
|
// this could be more efficient
|
|
p.dbHandle.usernames = []string{}
|
|
for username := range p.dbHandle.users {
|
|
p.dbHandle.usernames = append(p.dbHandle.usernames, username)
|
|
}
|
|
sort.Strings(p.dbHandle.usernames)
|
|
return nil
|
|
}
|
|
|
|
func (p MemoryProvider) dumpUsers() ([]User, error) {
|
|
users := []User{}
|
|
var err error
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return users, errMemoryProviderClosed
|
|
}
|
|
for _, username := range p.dbHandle.usernames {
|
|
user := p.dbHandle.users[username]
|
|
users = append(users, user)
|
|
}
|
|
return users, err
|
|
}
|
|
|
|
func (p MemoryProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
|
users := []User{}
|
|
var err error
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return users, errMemoryProviderClosed
|
|
}
|
|
if limit <= 0 {
|
|
return users, err
|
|
}
|
|
if len(username) > 0 {
|
|
if offset == 0 {
|
|
user, err := p.userExistsInternal(username)
|
|
if err == nil {
|
|
user.Password = ""
|
|
users = append(users, user)
|
|
}
|
|
}
|
|
return users, err
|
|
}
|
|
itNum := 0
|
|
if order == "ASC" {
|
|
for _, username := range p.dbHandle.usernames {
|
|
itNum++
|
|
if itNum <= offset {
|
|
continue
|
|
}
|
|
user := p.dbHandle.users[username]
|
|
user.Password = ""
|
|
users = append(users, user)
|
|
if len(users) >= limit {
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- {
|
|
itNum++
|
|
if itNum <= offset {
|
|
continue
|
|
}
|
|
username := p.dbHandle.usernames[i]
|
|
user := p.dbHandle.users[username]
|
|
user.Password = ""
|
|
users = append(users, user)
|
|
if len(users) >= limit {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return users, err
|
|
}
|
|
|
|
func (p MemoryProvider) userExists(username string) (User, error) {
|
|
p.dbHandle.lock.Lock()
|
|
defer p.dbHandle.lock.Unlock()
|
|
if p.dbHandle.isClosed {
|
|
return User{}, errMemoryProviderClosed
|
|
}
|
|
return p.userExistsInternal(username)
|
|
}
|
|
|
|
func (p MemoryProvider) userExistsInternal(username string) (User, error) {
|
|
if val, ok := p.dbHandle.users[username]; ok {
|
|
return val.getACopy(), nil
|
|
}
|
|
return User{}, &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)}
|
|
}
|
|
|
|
func (p MemoryProvider) getNextID() int64 {
|
|
nextID := int64(1)
|
|
for id := range p.dbHandle.usersIdx {
|
|
if id >= nextID {
|
|
nextID = id + 1
|
|
}
|
|
}
|
|
return nextID
|
|
}
|