add memory data provider and use it for portable mode

This commit is contained in:
Nicola Murino 2019-10-25 18:37:12 +02:00
parent a4cddf4f7f
commit 8cd0aec417
7 changed files with 317 additions and 11 deletions

View file

@ -161,7 +161,7 @@ func (p BoltProvider) userExists(username string) (User, error) {
}
u := bucket.Get([]byte(username))
if u == nil {
return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", user.Username)}
return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)}
}
return json.Unmarshal(u, &user)
})
@ -242,6 +242,9 @@ func (p BoltProvider) deleteUser(user User) error {
func (p BoltProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
users := []User{}
var err error
if limit <= 0 {
return users, err
}
if len(username) > 0 {
if offset == 0 {
user, err := p.userExists(username)
@ -252,9 +255,6 @@ func (p BoltProvider) getUsers(limit int, offset int, order string, username str
return users, err
}
err = p.dbHandle.View(func(tx *bolt.Tx) error {
if limit <= 0 {
return nil
}
bucket, _, err := getBuckets(tx)
if err != nil {
return err

View file

@ -37,6 +37,8 @@ const (
MySQLDataProviderName = "mysql"
// BoltDataProviderName name for bbolt key/value store provider
BoltDataProviderName = "bolt"
// MemoryDataProviderName name for memory provider
MemoryDataProviderName = "memory"
argonPwdPrefix = "$argon2id$"
bcryptPwdPrefix = "$2a$"
@ -50,7 +52,8 @@ const (
var (
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName, BoltDataProviderName}
SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
BoltDataProviderName, MemoryDataProviderName}
// ValidPerms list that contains all the valid permissions for an user
ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
PermCreateDirs, PermCreateSymlinks}
@ -179,6 +182,8 @@ func Initialize(cnf Config, basePath string) error {
err = initializeMySQLProvider()
} else if config.Driver == BoltDataProviderName {
err = initializeBoltProvider(basePath)
} else if config.Driver == MemoryDataProviderName {
err = initializeMemoryProvider()
} else {
err = fmt.Errorf("unsupported data provider: %v", config.Driver)
}

279
dataprovider/memory.go Normal file
View file

@ -0,0 +1,279 @@
package dataprovider
import (
"errors"
"fmt"
"sort"
"sync"
"time"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
)
var (
errMemoryProviderClosed = errors.New("memory provider is closed")
)
type memoryProviderHandle struct {
isClosed bool
// slice with ordered usernames
usernames []string
// mapping between ID and username
usersIdx map[int64]string
// map for users, username is the key
users map[string]User
lock *sync.Mutex
}
// MemoryProvider auth provider for a memory store
type MemoryProvider struct {
dbHandle *memoryProviderHandle
}
func initializeMemoryProvider() error {
provider = MemoryProvider{
dbHandle: &memoryProviderHandle{
isClosed: false,
usernames: []string{},
usersIdx: make(map[int64]string),
users: make(map[string]User),
lock: new(sync.Mutex),
},
}
return nil
}
func (p MemoryProvider) checkAvailability() error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
return nil
}
func (p MemoryProvider) close() error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
p.dbHandle.isClosed = true
return nil
}
func (p MemoryProvider) validateUserAndPass(username string, password string) (User, error) {
var user User
if len(password) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
return user, err
}
return checkUserAndPass(user, password)
}
func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
var user User
if len(pubKey) == 0 {
return user, "", errors.New("Credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
return user, "", err
}
return checkUserAndPubKey(user, pubKey)
}
func (p MemoryProvider) getUserByID(ID int64) (User, error) {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return User{}, errMemoryProviderClosed
}
if val, ok := p.dbHandle.usersIdx[ID]; ok {
return p.userExistsInternal(val)
}
return User{}, &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)}
}
func (p MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
providerLog(logger.LevelWarn, "unable to update quota for user %v error: %v", username, err)
return err
}
if reset {
user.UsedQuotaSize = sizeAdd
user.UsedQuotaFiles = filesAdd
} else {
user.UsedQuotaSize += sizeAdd
user.UsedQuotaFiles += filesAdd
}
user.LastQuotaUpdate = utils.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
func (p MemoryProvider) getUsedQuota(username string) (int, int64, error) {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return 0, 0, errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
providerLog(logger.LevelWarn, "unable to get quota for user %v error: %v", username, err)
return 0, 0, err
}
return user.UsedQuotaFiles, user.UsedQuotaSize, err
}
func (p MemoryProvider) addUser(user User) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
err := validateUser(&user)
if err != nil {
return err
}
_, err = p.userExistsInternal(user.Username)
if err == nil {
return fmt.Errorf("username %v already exists", user.Username)
}
user.ID = p.getNextID()
p.dbHandle.users[user.Username] = user
p.dbHandle.usersIdx[user.ID] = user.Username
p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username)
sort.Strings(p.dbHandle.usernames)
return nil
}
func (p MemoryProvider) updateUser(user User) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
err := validateUser(&user)
if err != nil {
return err
}
_, err = p.userExistsInternal(user.Username)
if err != nil {
return err
}
p.dbHandle.users[user.Username] = user
return nil
}
func (p MemoryProvider) deleteUser(user User) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
_, err := p.userExistsInternal(user.Username)
if err != nil {
return err
}
delete(p.dbHandle.users, user.Username)
delete(p.dbHandle.usersIdx, user.ID)
// this could be more efficient
p.dbHandle.usernames = []string{}
for username := range p.dbHandle.users {
p.dbHandle.usernames = append(p.dbHandle.usernames, username)
}
sort.Strings(p.dbHandle.usernames)
return nil
}
func (p MemoryProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
users := []User{}
var err error
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return users, errMemoryProviderClosed
}
if limit <= 0 {
return users, err
}
if len(username) > 0 {
if offset == 0 {
user, err := p.userExistsInternal(username)
if err == nil {
user.Password = ""
users = append(users, user)
}
}
return users, err
}
itNum := 0
if order == "ASC" {
for _, username := range p.dbHandle.usernames {
itNum++
if itNum <= offset {
continue
}
user := p.dbHandle.users[username]
user.Password = ""
users = append(users, user)
if len(users) >= limit {
break
}
}
} else {
for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- {
itNum++
if itNum <= offset {
continue
}
username := p.dbHandle.usernames[i]
user := p.dbHandle.users[username]
user.Password = ""
users = append(users, user)
if len(users) >= limit {
break
}
}
}
return users, err
}
func (p MemoryProvider) userExists(username string) (User, error) {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return User{}, errMemoryProviderClosed
}
return p.userExistsInternal(username)
}
func (p MemoryProvider) userExistsInternal(username string) (User, error) {
if val, ok := p.dbHandle.users[username]; ok {
return val.getACopy(), nil
}
return User{}, &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)}
}
func (p MemoryProvider) getNextID() int64 {
nextID := int64(1)
for id := range p.dbHandle.usersIdx {
if id >= nextID {
nextID = id + 1
}
}
return nextID
}

View file

@ -189,3 +189,28 @@ func (u *User) GetInfoString() string {
}
return result
}
func (u *User) getACopy() User {
pubKeys := make([]string, len(u.PublicKeys))
copy(pubKeys, u.PublicKeys)
permissions := make([]string, len(u.Permissions))
copy(permissions, u.Permissions)
return User{
ID: u.ID,
Username: u.Username,
Password: u.Password,
PublicKeys: pubKeys,
HomeDir: u.HomeDir,
UID: u.UID,
GID: u.GID,
MaxSessions: u.MaxSessions,
QuotaSize: u.QuotaSize,
QuotaFiles: u.QuotaFiles,
Permissions: permissions,
UsedQuotaSize: u.UsedQuotaSize,
UsedQuotaFiles: u.UsedQuotaFiles,
LastQuotaUpdate: u.LastQuotaUpdate,
UploadBandwidth: u.UploadBandwidth,
DownloadBandwidth: u.DownloadBandwidth,
}
}

View file

@ -143,11 +143,9 @@ func (s *Service) StartPortableMode(sftpdPort int, enableSCP bool) error {
}
tempDir := os.TempDir()
instanceID := xid.New().String()
databasePath := filepath.Join(tempDir, instanceID+".db")
s.LogFilePath = filepath.Join(tempDir, instanceID+".log")
dataProviderConf := config.GetProviderConf()
dataProviderConf.Driver = dataprovider.BoltDataProviderName
dataProviderConf.Name = databasePath
dataProviderConf.Driver = dataprovider.MemoryDataProviderName
config.SetProviderConf(dataProviderConf)
httpdConf := config.GetHTTPDConfig()
httpdConf.BindPort = 0

View file

@ -277,8 +277,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if err := ssh.Unmarshal(req.Payload, &msg); err == nil {
name, scpArgs, err := parseCommandPayload(msg.Command)
connection.Log(logger.LevelDebug, logSender, "new exec command: %#v args: %v user: %v, error: %v",
name, scpArgs,
connection.User.Username, err)
name, scpArgs, connection.User.Username, err)
if err == nil && name == "scp" && len(scpArgs) >= 2 {
ok = true
connection.protocol = protocolSCP

View file

@ -200,7 +200,7 @@ func TestBasicSFTPHandling(t *testing.T) {
testFileSize := int64(65535)
expectedQuotaSize := user.UsedQuotaSize + testFileSize
expectedQuotaFiles := user.UsedQuotaFiles + 1
err = createTestFile(testFilePath, testFileSize)
createTestFile(testFilePath, testFileSize)
err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client)
if err == nil {
t.Errorf("upload a file to a missing dir must fail")