mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
dataprovider move db handle to provider struct
This is needed to support non SQL providers
This commit is contained in:
parent
51aacae3c5
commit
cb87fe811a
7 changed files with 74 additions and 58 deletions
13
README.md
13
README.md
|
@ -40,7 +40,14 @@ $ go get -u github.com/drakkan/sftpgo
|
|||
|
||||
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
|
||||
|
||||
Version info can be embedded populating the following variables at build time:
|
||||
SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time.
|
||||
On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from it's command prompt.
|
||||
|
||||
The compiler is a build time only dependency, it is not not required at runtime.
|
||||
|
||||
If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL and MySQL will work and you don't need a `C` compiler for building.
|
||||
|
||||
Version info, such as git commit and build date, can be embedded setting the following string variables at build time:
|
||||
|
||||
- `github.com/drakkan/sftpgo/utils.commit`
|
||||
- `github.com/drakkan/sftpgo/utils.date`
|
||||
|
@ -54,11 +61,11 @@ go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git descr
|
|||
and you will get a version that includes git commit and build date like this one:
|
||||
|
||||
```bash
|
||||
./sftpgo -v
|
||||
sftpgo -v
|
||||
SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z
|
||||
```
|
||||
|
||||
A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
|
||||
For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
|
||||
|
||||
Alternately you can use distro packages:
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
package dataprovider
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
@ -33,7 +32,6 @@ const (
|
|||
var (
|
||||
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
|
||||
SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName}
|
||||
dbHandle *sql.DB
|
||||
config Config
|
||||
provider Provider
|
||||
sqlPlaceholders []string
|
||||
|
@ -124,13 +122,10 @@ func Initialize(cnf Config, basePath string) error {
|
|||
config = cnf
|
||||
sqlPlaceholders = getSQLPlaceholders()
|
||||
if config.Driver == SQLiteDataProviderName {
|
||||
provider = SQLiteProvider{}
|
||||
return initializeSQLiteProvider(basePath)
|
||||
} else if config.Driver == PGSSQLDataProviderName {
|
||||
provider = PGSQLProvider{}
|
||||
return initializePGSQLProvider()
|
||||
} else if config.Driver == MySQLDataProviderName {
|
||||
provider = MySQLProvider{}
|
||||
return initializeMySQLProvider()
|
||||
}
|
||||
return fmt.Errorf("Unsupported data provider: %v", config.Driver)
|
||||
|
@ -226,7 +221,8 @@ func validateUser(user *User) error {
|
|||
return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)}
|
||||
}
|
||||
}
|
||||
if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) {
|
||||
if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) &&
|
||||
!strings.HasPrefix(user.Password, bcryptPwdPrefix) {
|
||||
pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
// MySQLProvider auth provider for MySQL/MariaDB database
|
||||
type MySQLProvider struct {
|
||||
dbHandle *sql.DB
|
||||
}
|
||||
|
||||
func initializeMySQLProvider() error {
|
||||
|
@ -22,13 +23,14 @@ func initializeMySQLProvider() error {
|
|||
} else {
|
||||
connectionString = config.ConnectionString
|
||||
}
|
||||
dbHandle, err = sql.Open("mysql", connectionString)
|
||||
dbHandle, err := sql.Open("mysql", connectionString)
|
||||
if err == nil {
|
||||
numCPU := runtime.NumCPU()
|
||||
logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
|
||||
dbHandle.SetMaxIdleConns(numCPU)
|
||||
dbHandle.SetMaxOpenConns(numCPU)
|
||||
dbHandle.SetConnMaxLifetime(1800 * time.Second)
|
||||
provider = MySQLProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err)
|
||||
}
|
||||
|
@ -36,24 +38,24 @@ func initializeMySQLProvider() error {
|
|||
}
|
||||
|
||||
func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) {
|
||||
return sqlCommonValidateUserAndPass(username, password)
|
||||
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey)
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) getUserByID(ID int64) (User, error) {
|
||||
return sqlCommonGetUserByID(ID)
|
||||
return sqlCommonGetUserByID(ID, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
|
||||
tx, err := dbHandle.Begin()
|
||||
tx, err := p.dbHandle.Begin()
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
|
||||
return err
|
||||
}
|
||||
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
|
||||
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
|
||||
if err == nil {
|
||||
err = tx.Commit()
|
||||
} else {
|
||||
|
@ -66,25 +68,25 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
|
|||
}
|
||||
|
||||
func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
|
||||
return sqlCommonGetUsedQuota(username)
|
||||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) userExists(username string) (User, error) {
|
||||
return sqlCommonCheckUserExists(username)
|
||||
return sqlCommonCheckUserExists(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) addUser(user User) error {
|
||||
return sqlCommonAddUser(user)
|
||||
return sqlCommonAddUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) updateUser(user User) error {
|
||||
return sqlCommonUpdateUser(user)
|
||||
return sqlCommonUpdateUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) deleteUser(user User) error {
|
||||
return sqlCommonDeleteUser(user)
|
||||
return sqlCommonDeleteUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||
return sqlCommonGetUsers(limit, offset, order, username)
|
||||
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
// PGSQLProvider auth provider for PostgreSQL database
|
||||
type PGSQLProvider struct {
|
||||
dbHandle *sql.DB
|
||||
}
|
||||
|
||||
func initializePGSQLProvider() error {
|
||||
|
@ -21,12 +22,13 @@ func initializePGSQLProvider() error {
|
|||
} else {
|
||||
connectionString = config.ConnectionString
|
||||
}
|
||||
dbHandle, err = sql.Open("postgres", connectionString)
|
||||
dbHandle, err := sql.Open("postgres", connectionString)
|
||||
if err == nil {
|
||||
numCPU := runtime.NumCPU()
|
||||
logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
|
||||
dbHandle.SetMaxIdleConns(numCPU)
|
||||
dbHandle.SetMaxOpenConns(numCPU)
|
||||
provider = PGSQLProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err)
|
||||
}
|
||||
|
@ -34,24 +36,24 @@ func initializePGSQLProvider() error {
|
|||
}
|
||||
|
||||
func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) {
|
||||
return sqlCommonValidateUserAndPass(username, password)
|
||||
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey)
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) getUserByID(ID int64) (User, error) {
|
||||
return sqlCommonGetUserByID(ID)
|
||||
return sqlCommonGetUserByID(ID, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
|
||||
tx, err := dbHandle.Begin()
|
||||
tx, err := p.dbHandle.Begin()
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
|
||||
return err
|
||||
}
|
||||
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
|
||||
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
|
||||
if err == nil {
|
||||
err = tx.Commit()
|
||||
} else {
|
||||
|
@ -64,25 +66,25 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
|
|||
}
|
||||
|
||||
func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
|
||||
return sqlCommonGetUsedQuota(username)
|
||||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) userExists(username string) (User, error) {
|
||||
return sqlCommonCheckUserExists(username)
|
||||
return sqlCommonCheckUserExists(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) addUser(user User) error {
|
||||
return sqlCommonAddUser(user)
|
||||
return sqlCommonAddUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) updateUser(user User) error {
|
||||
return sqlCommonUpdateUser(user)
|
||||
return sqlCommonUpdateUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) deleteUser(user User) error {
|
||||
return sqlCommonDeleteUser(user)
|
||||
return sqlCommonDeleteUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||
return sqlCommonGetUsers(limit, offset, order, username)
|
||||
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
"github.com/drakkan/sftpgo/utils"
|
||||
)
|
||||
|
||||
func getUserByUsername(username string) (User, error) {
|
||||
func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
q := getUserByUsernameQuery()
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
|
@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) {
|
|||
return getUserFromDbRow(row, nil)
|
||||
}
|
||||
|
||||
func sqlCommonValidateUserAndPass(username string, password string) (User, error) {
|
||||
func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
if len(password) == 0 {
|
||||
return user, errors.New("Credentials cannot be null or empty")
|
||||
}
|
||||
user, err := getUserByUsername(username)
|
||||
user, err := getUserByUsername(username, dbHandle)
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
|
||||
} else {
|
||||
|
@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error
|
|||
return user, err
|
||||
}
|
||||
|
||||
func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) {
|
||||
func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
if len(pubKey) == 0 {
|
||||
return user, errors.New("Credentials cannot be null or empty")
|
||||
}
|
||||
user, err := getUserByUsername(username)
|
||||
user, err := getUserByUsername(username, dbHandle)
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
|
||||
return user, err
|
||||
|
@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error
|
|||
return user, errors.New("Invalid credentials")
|
||||
}
|
||||
|
||||
func sqlCommonGetUserByID(ID int64) (User, error) {
|
||||
func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
q := getUserByIDQuery()
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
|
@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) {
|
|||
return getUserFromDbRow(row, nil)
|
||||
}
|
||||
|
||||
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error {
|
||||
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
|
||||
q := getUpdateQuotaQuery(reset)
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
if err != nil {
|
||||
|
@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetUsedQuota(username string) (int, int64, error) {
|
||||
func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
|
||||
q := getQuotaQuery()
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
if err != nil {
|
||||
|
@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) {
|
|||
return usedFiles, usedSize, err
|
||||
}
|
||||
|
||||
func sqlCommonCheckUserExists(username string) (User, error) {
|
||||
func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
q := getUserByUsernameQuery()
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
|
@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) {
|
|||
return getUserFromDbRow(row, nil)
|
||||
}
|
||||
|
||||
func sqlCommonAddUser(user User) error {
|
||||
func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
|
||||
err := validateUser(&user)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateUser(user User) error {
|
||||
func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
|
||||
err := validateUser(&user)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteUser(user User) error {
|
||||
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
|
||||
q := getDeleteUserQuery()
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
if err != nil {
|
||||
|
@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||
func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
|
||||
users := []User{}
|
||||
q := getUsersQuery(order, username)
|
||||
stmt, err := dbHandle.Prepare(q)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
// SQLiteProvider auth provider for SQLite database
|
||||
type SQLiteProvider struct {
|
||||
dbHandle *sql.DB
|
||||
}
|
||||
|
||||
func initializeSQLiteProvider(basePath string) error {
|
||||
|
@ -36,10 +37,11 @@ func initializeSQLiteProvider(basePath string) error {
|
|||
} else {
|
||||
connectionString = config.ConnectionString
|
||||
}
|
||||
dbHandle, err = sql.Open("sqlite3", connectionString)
|
||||
dbHandle, err := sql.Open("sqlite3", connectionString)
|
||||
if err == nil {
|
||||
logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString)
|
||||
dbHandle.SetMaxOpenConns(1)
|
||||
provider = SQLiteProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err)
|
||||
}
|
||||
|
@ -47,43 +49,43 @@ func initializeSQLiteProvider(basePath string) error {
|
|||
}
|
||||
|
||||
func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) {
|
||||
return sqlCommonValidateUserAndPass(username, password)
|
||||
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey)
|
||||
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) getUserByID(ID int64) (User, error) {
|
||||
return sqlCommonGetUserByID(ID)
|
||||
return sqlCommonGetUserByID(ID, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
|
||||
// we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block
|
||||
// the database access since it will try to open a new connection
|
||||
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
|
||||
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
|
||||
return sqlCommonGetUsedQuota(username)
|
||||
return sqlCommonGetUsedQuota(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) userExists(username string) (User, error) {
|
||||
return sqlCommonCheckUserExists(username)
|
||||
return sqlCommonCheckUserExists(username, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) addUser(user User) error {
|
||||
return sqlCommonAddUser(user)
|
||||
return sqlCommonAddUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) updateUser(user User) error {
|
||||
return sqlCommonUpdateUser(user)
|
||||
return sqlCommonUpdateUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) deleteUser(user User) error {
|
||||
return sqlCommonDeleteUser(user)
|
||||
return sqlCommonDeleteUser(user, p.dbHandle)
|
||||
}
|
||||
|
||||
func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||
return sqlCommonGetUsers(limit, offset, order, username)
|
||||
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
|
||||
}
|
||||
|
|
|
@ -99,11 +99,18 @@ func TestUploadFiles(t *testing.T) {
|
|||
uploadMode = oldUploadMode
|
||||
}
|
||||
|
||||
func TestLoginWithInvalidHome(t *testing.T) {
|
||||
func TestWithInvalidHome(t *testing.T) {
|
||||
u := dataprovider.User{}
|
||||
u.HomeDir = "home_rel_path"
|
||||
_, err := loginUser(u)
|
||||
if err == nil {
|
||||
t.Errorf("login a user with an invalid home_dir must fail")
|
||||
}
|
||||
c := Connection{
|
||||
User: u,
|
||||
}
|
||||
err = c.isSubDir("dir_rel_path")
|
||||
if err == nil {
|
||||
t.Errorf("tested path is not a home subdir")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue