diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index 0821ffef..5d071956 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -161,7 +161,7 @@ func (p BoltProvider) userExists(username string) (User, error) { } u := bucket.Get([]byte(username)) if u == nil { - return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", user.Username)} + return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)} } return json.Unmarshal(u, &user) }) @@ -242,6 +242,9 @@ func (p BoltProvider) deleteUser(user User) error { func (p BoltProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { users := []User{} var err error + if limit <= 0 { + return users, err + } if len(username) > 0 { if offset == 0 { user, err := p.userExists(username) @@ -252,9 +255,6 @@ func (p BoltProvider) getUsers(limit int, offset int, order string, username str return users, err } err = p.dbHandle.View(func(tx *bolt.Tx) error { - if limit <= 0 { - return nil - } bucket, _, err := getBuckets(tx) if err != nil { return err diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index afc07457..b038a4b7 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -37,6 +37,8 @@ const ( MySQLDataProviderName = "mysql" // BoltDataProviderName name for bbolt key/value store provider BoltDataProviderName = "bolt" + // MemoryDataProviderName name for memory provider + MemoryDataProviderName = "memory" argonPwdPrefix = "$argon2id$" bcryptPwdPrefix = "$2a$" @@ -50,7 +52,8 @@ const ( var ( // SupportedProviders data provider configured in the sftpgo.conf file must match of these strings - SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName, BoltDataProviderName} + SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName, + BoltDataProviderName, MemoryDataProviderName} // ValidPerms list that contains all the valid permissions for an user ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete, PermCreateDirs, PermCreateSymlinks} @@ -179,6 +182,8 @@ func Initialize(cnf Config, basePath string) error { err = initializeMySQLProvider() } else if config.Driver == BoltDataProviderName { err = initializeBoltProvider(basePath) + } else if config.Driver == MemoryDataProviderName { + err = initializeMemoryProvider() } else { err = fmt.Errorf("unsupported data provider: %v", config.Driver) } diff --git a/dataprovider/memory.go b/dataprovider/memory.go new file mode 100644 index 00000000..4c4b7fba --- /dev/null +++ b/dataprovider/memory.go @@ -0,0 +1,279 @@ +package dataprovider + +import ( + "errors" + "fmt" + "sort" + "sync" + "time" + + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" +) + +var ( + errMemoryProviderClosed = errors.New("memory provider is closed") +) + +type memoryProviderHandle struct { + isClosed bool + // slice with ordered usernames + usernames []string + // mapping between ID and username + usersIdx map[int64]string + // map for users, username is the key + users map[string]User + lock *sync.Mutex +} + +// MemoryProvider auth provider for a memory store +type MemoryProvider struct { + dbHandle *memoryProviderHandle +} + +func initializeMemoryProvider() error { + provider = MemoryProvider{ + dbHandle: &memoryProviderHandle{ + isClosed: false, + usernames: []string{}, + usersIdx: make(map[int64]string), + users: make(map[string]User), + lock: new(sync.Mutex), + }, + } + return nil +} + +func (p MemoryProvider) checkAvailability() error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + return nil +} + +func (p MemoryProvider) close() error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + p.dbHandle.isClosed = true + return nil +} + +func (p MemoryProvider) validateUserAndPass(username string, password string) (User, error) { + var user User + if len(password) == 0 { + return user, errors.New("Credentials cannot be null or empty") + } + user, err := p.userExists(username) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err) + return user, err + } + return checkUserAndPass(user, password) +} + +func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) { + var user User + if len(pubKey) == 0 { + return user, "", errors.New("Credentials cannot be null or empty") + } + user, err := p.userExists(username) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err) + return user, "", err + } + return checkUserAndPubKey(user, pubKey) +} + +func (p MemoryProvider) getUserByID(ID int64) (User, error) { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return User{}, errMemoryProviderClosed + } + if val, ok := p.dbHandle.usersIdx[ID]; ok { + return p.userExistsInternal(val) + } + return User{}, &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)} +} + +func (p MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelWarn, "unable to update quota for user %v error: %v", username, err) + return err + } + if reset { + user.UsedQuotaSize = sizeAdd + user.UsedQuotaFiles = filesAdd + } else { + user.UsedQuotaSize += sizeAdd + user.UsedQuotaFiles += filesAdd + } + user.LastQuotaUpdate = utils.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.users[user.Username] = user + return nil +} + +func (p MemoryProvider) getUsedQuota(username string) (int, int64, error) { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return 0, 0, errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelWarn, "unable to get quota for user %v error: %v", username, err) + return 0, 0, err + } + return user.UsedQuotaFiles, user.UsedQuotaSize, err +} + +func (p MemoryProvider) addUser(user User) error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + err := validateUser(&user) + if err != nil { + return err + } + _, err = p.userExistsInternal(user.Username) + if err == nil { + return fmt.Errorf("username %v already exists", user.Username) + } + user.ID = p.getNextID() + p.dbHandle.users[user.Username] = user + p.dbHandle.usersIdx[user.ID] = user.Username + p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username) + sort.Strings(p.dbHandle.usernames) + return nil +} + +func (p MemoryProvider) updateUser(user User) error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + err := validateUser(&user) + if err != nil { + return err + } + _, err = p.userExistsInternal(user.Username) + if err != nil { + return err + } + p.dbHandle.users[user.Username] = user + return nil +} + +func (p MemoryProvider) deleteUser(user User) error { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + _, err := p.userExistsInternal(user.Username) + if err != nil { + return err + } + delete(p.dbHandle.users, user.Username) + delete(p.dbHandle.usersIdx, user.ID) + // this could be more efficient + p.dbHandle.usernames = []string{} + for username := range p.dbHandle.users { + p.dbHandle.usernames = append(p.dbHandle.usernames, username) + } + sort.Strings(p.dbHandle.usernames) + return nil +} + +func (p MemoryProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { + users := []User{} + var err error + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return users, errMemoryProviderClosed + } + if limit <= 0 { + return users, err + } + if len(username) > 0 { + if offset == 0 { + user, err := p.userExistsInternal(username) + if err == nil { + user.Password = "" + users = append(users, user) + } + } + return users, err + } + itNum := 0 + if order == "ASC" { + for _, username := range p.dbHandle.usernames { + itNum++ + if itNum <= offset { + continue + } + user := p.dbHandle.users[username] + user.Password = "" + users = append(users, user) + if len(users) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + username := p.dbHandle.usernames[i] + user := p.dbHandle.users[username] + user.Password = "" + users = append(users, user) + if len(users) >= limit { + break + } + } + } + return users, err +} + +func (p MemoryProvider) userExists(username string) (User, error) { + p.dbHandle.lock.Lock() + defer p.dbHandle.lock.Unlock() + if p.dbHandle.isClosed { + return User{}, errMemoryProviderClosed + } + return p.userExistsInternal(username) +} + +func (p MemoryProvider) userExistsInternal(username string) (User, error) { + if val, ok := p.dbHandle.users[username]; ok { + return val.getACopy(), nil + } + return User{}, &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)} +} + +func (p MemoryProvider) getNextID() int64 { + nextID := int64(1) + for id := range p.dbHandle.usersIdx { + if id >= nextID { + nextID = id + 1 + } + } + return nextID +} diff --git a/dataprovider/user.go b/dataprovider/user.go index 641d33d9..c86796d3 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -189,3 +189,28 @@ func (u *User) GetInfoString() string { } return result } + +func (u *User) getACopy() User { + pubKeys := make([]string, len(u.PublicKeys)) + copy(pubKeys, u.PublicKeys) + permissions := make([]string, len(u.Permissions)) + copy(permissions, u.Permissions) + return User{ + ID: u.ID, + Username: u.Username, + Password: u.Password, + PublicKeys: pubKeys, + HomeDir: u.HomeDir, + UID: u.UID, + GID: u.GID, + MaxSessions: u.MaxSessions, + QuotaSize: u.QuotaSize, + QuotaFiles: u.QuotaFiles, + Permissions: permissions, + UsedQuotaSize: u.UsedQuotaSize, + UsedQuotaFiles: u.UsedQuotaFiles, + LastQuotaUpdate: u.LastQuotaUpdate, + UploadBandwidth: u.UploadBandwidth, + DownloadBandwidth: u.DownloadBandwidth, + } +} diff --git a/service/service.go b/service/service.go index 4240d4ca..7a917e9c 100644 --- a/service/service.go +++ b/service/service.go @@ -143,11 +143,9 @@ func (s *Service) StartPortableMode(sftpdPort int, enableSCP bool) error { } tempDir := os.TempDir() instanceID := xid.New().String() - databasePath := filepath.Join(tempDir, instanceID+".db") s.LogFilePath = filepath.Join(tempDir, instanceID+".log") dataProviderConf := config.GetProviderConf() - dataProviderConf.Driver = dataprovider.BoltDataProviderName - dataProviderConf.Name = databasePath + dataProviderConf.Driver = dataprovider.MemoryDataProviderName config.SetProviderConf(dataProviderConf) httpdConf := config.GetHTTPDConfig() httpdConf.BindPort = 0 diff --git a/sftpd/server.go b/sftpd/server.go index 6c7bf30f..0c1c73e6 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -277,8 +277,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server if err := ssh.Unmarshal(req.Payload, &msg); err == nil { name, scpArgs, err := parseCommandPayload(msg.Command) connection.Log(logger.LevelDebug, logSender, "new exec command: %#v args: %v user: %v, error: %v", - name, scpArgs, - connection.User.Username, err) + name, scpArgs, connection.User.Username, err) if err == nil && name == "scp" && len(scpArgs) >= 2 { ok = true connection.protocol = protocolSCP diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 11e679ff..6289ad08 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -200,7 +200,7 @@ func TestBasicSFTPHandling(t *testing.T) { testFileSize := int64(65535) expectedQuotaSize := user.UsedQuotaSize + testFileSize expectedQuotaFiles := user.UsedQuotaFiles + 1 - err = createTestFile(testFilePath, testFileSize) + createTestFile(testFilePath, testFileSize) err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client) if err == nil { t.Errorf("upload a file to a missing dir must fail")