dataprovider move db handle to provider struct

This is needed to support non SQL providers
This commit is contained in:
Nicola Murino 2019-08-11 14:53:37 +02:00
parent 51aacae3c5
commit cb87fe811a
7 changed files with 74 additions and 58 deletions

View file

@ -40,7 +40,14 @@ $ go get -u github.com/drakkan/sftpgo
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
Version info can be embedded populating the following variables at build time:
SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time.
On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from it's command prompt.
The compiler is a build time only dependency, it is not not required at runtime.
If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL and MySQL will work and you don't need a `C` compiler for building.
Version info, such as git commit and build date, can be embedded setting the following string variables at build time:
- `github.com/drakkan/sftpgo/utils.commit`
- `github.com/drakkan/sftpgo/utils.date`
@ -54,11 +61,11 @@ go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git descr
and you will get a version that includes git commit and build date like this one:
```bash
./sftpgo -v
sftpgo -v
SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z
```
A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
Alternately you can use distro packages:

View file

@ -4,7 +4,6 @@
package dataprovider
import (
"database/sql"
"fmt"
"path/filepath"
"strings"
@ -33,7 +32,6 @@ const (
var (
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName}
dbHandle *sql.DB
config Config
provider Provider
sqlPlaceholders []string
@ -124,13 +122,10 @@ func Initialize(cnf Config, basePath string) error {
config = cnf
sqlPlaceholders = getSQLPlaceholders()
if config.Driver == SQLiteDataProviderName {
provider = SQLiteProvider{}
return initializeSQLiteProvider(basePath)
} else if config.Driver == PGSSQLDataProviderName {
provider = PGSQLProvider{}
return initializePGSQLProvider()
} else if config.Driver == MySQLDataProviderName {
provider = MySQLProvider{}
return initializeMySQLProvider()
}
return fmt.Errorf("Unsupported data provider: %v", config.Driver)
@ -226,7 +221,8 @@ func validateUser(user *User) error {
return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)}
}
}
if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) {
if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) &&
!strings.HasPrefix(user.Password, bcryptPwdPrefix) {
pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams)
if err != nil {
return err

View file

@ -11,6 +11,7 @@ import (
// MySQLProvider auth provider for MySQL/MariaDB database
type MySQLProvider struct {
dbHandle *sql.DB
}
func initializeMySQLProvider() error {
@ -22,13 +23,14 @@ func initializeMySQLProvider() error {
} else {
connectionString = config.ConnectionString
}
dbHandle, err = sql.Open("mysql", connectionString)
dbHandle, err := sql.Open("mysql", connectionString)
if err == nil {
numCPU := runtime.NumCPU()
logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
dbHandle.SetMaxIdleConns(numCPU)
dbHandle.SetMaxOpenConns(numCPU)
dbHandle.SetConnMaxLifetime(1800 * time.Second)
provider = MySQLProvider{dbHandle: dbHandle}
} else {
logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err)
}
@ -36,24 +38,24 @@ func initializeMySQLProvider() error {
}
func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password)
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
}
func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey)
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
}
func (p MySQLProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID)
return sqlCommonGetUserByID(ID, p.dbHandle)
}
func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
tx, err := dbHandle.Begin()
tx, err := p.dbHandle.Begin()
if err != nil {
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
return err
}
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
if err == nil {
err = tx.Commit()
} else {
@ -66,25 +68,25 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
}
func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username)
return sqlCommonGetUsedQuota(username, p.dbHandle)
}
func (p MySQLProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username)
return sqlCommonCheckUserExists(username, p.dbHandle)
}
func (p MySQLProvider) addUser(user User) error {
return sqlCommonAddUser(user)
return sqlCommonAddUser(user, p.dbHandle)
}
func (p MySQLProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user)
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p MySQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user)
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username)
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
}

View file

@ -10,6 +10,7 @@ import (
// PGSQLProvider auth provider for PostgreSQL database
type PGSQLProvider struct {
dbHandle *sql.DB
}
func initializePGSQLProvider() error {
@ -21,12 +22,13 @@ func initializePGSQLProvider() error {
} else {
connectionString = config.ConnectionString
}
dbHandle, err = sql.Open("postgres", connectionString)
dbHandle, err := sql.Open("postgres", connectionString)
if err == nil {
numCPU := runtime.NumCPU()
logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
dbHandle.SetMaxIdleConns(numCPU)
dbHandle.SetMaxOpenConns(numCPU)
provider = PGSQLProvider{dbHandle: dbHandle}
} else {
logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err)
}
@ -34,24 +36,24 @@ func initializePGSQLProvider() error {
}
func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password)
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
}
func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey)
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
}
func (p PGSQLProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID)
return sqlCommonGetUserByID(ID, p.dbHandle)
}
func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
tx, err := dbHandle.Begin()
tx, err := p.dbHandle.Begin()
if err != nil {
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
return err
}
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
if err == nil {
err = tx.Commit()
} else {
@ -64,25 +66,25 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
}
func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username)
return sqlCommonGetUsedQuota(username, p.dbHandle)
}
func (p PGSQLProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username)
return sqlCommonCheckUserExists(username, p.dbHandle)
}
func (p PGSQLProvider) addUser(user User) error {
return sqlCommonAddUser(user)
return sqlCommonAddUser(user, p.dbHandle)
}
func (p PGSQLProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user)
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p PGSQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user)
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username)
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
}

View file

@ -16,7 +16,7 @@ import (
"github.com/drakkan/sftpgo/utils"
)
func getUserByUsername(username string) (User, error) {
func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q)
@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonValidateUserAndPass(username string, password string) (User, error) {
func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
var user User
if len(password) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := getUserByUsername(username)
user, err := getUserByUsername(username, dbHandle)
if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
} else {
@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error
return user, err
}
func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) {
func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
var user User
if len(pubKey) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := getUserByUsername(username)
user, err := getUserByUsername(username, dbHandle)
if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
return user, err
@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error
return user, errors.New("Invalid credentials")
}
func sqlCommonGetUserByID(ID int64) (User, error) {
func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByIDQuery()
stmt, err := dbHandle.Prepare(q)
@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error {
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
q := getUpdateQuotaQuery(reset)
stmt, err := dbHandle.Prepare(q)
if err != nil {
@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
return err
}
func sqlCommonGetUsedQuota(username string) (int, int64, error) {
func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
q := getQuotaQuery()
stmt, err := dbHandle.Prepare(q)
if err != nil {
@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) {
return usedFiles, usedSize, err
}
func sqlCommonCheckUserExists(username string) (User, error) {
func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q)
@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonAddUser(user User) error {
func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
if err != nil {
return err
@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error {
return err
}
func sqlCommonUpdateUser(user User) error {
func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
if err != nil {
return err
@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error {
return err
}
func sqlCommonDeleteUser(user User) error {
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
q := getDeleteUserQuery()
stmt, err := dbHandle.Prepare(q)
if err != nil {
@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error {
return err
}
func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) {
func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
users := []User{}
q := getUsersQuery(order, username)
stmt, err := dbHandle.Prepare(q)

View file

@ -12,6 +12,7 @@ import (
// SQLiteProvider auth provider for SQLite database
type SQLiteProvider struct {
dbHandle *sql.DB
}
func initializeSQLiteProvider(basePath string) error {
@ -36,10 +37,11 @@ func initializeSQLiteProvider(basePath string) error {
} else {
connectionString = config.ConnectionString
}
dbHandle, err = sql.Open("sqlite3", connectionString)
dbHandle, err := sql.Open("sqlite3", connectionString)
if err == nil {
logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString)
dbHandle.SetMaxOpenConns(1)
provider = SQLiteProvider{dbHandle: dbHandle}
} else {
logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err)
}
@ -47,43 +49,43 @@ func initializeSQLiteProvider(basePath string) error {
}
func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password)
return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
}
func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey)
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
}
func (p SQLiteProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID)
return sqlCommonGetUserByID(ID, p.dbHandle)
}
func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
// we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block
// the database access since it will try to open a new connection
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username)
return sqlCommonGetUsedQuota(username, p.dbHandle)
}
func (p SQLiteProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username)
return sqlCommonCheckUserExists(username, p.dbHandle)
}
func (p SQLiteProvider) addUser(user User) error {
return sqlCommonAddUser(user)
return sqlCommonAddUser(user, p.dbHandle)
}
func (p SQLiteProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user)
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p SQLiteProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user)
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username)
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
}

View file

@ -99,11 +99,18 @@ func TestUploadFiles(t *testing.T) {
uploadMode = oldUploadMode
}
func TestLoginWithInvalidHome(t *testing.T) {
func TestWithInvalidHome(t *testing.T) {
u := dataprovider.User{}
u.HomeDir = "home_rel_path"
_, err := loginUser(u)
if err == nil {
t.Errorf("login a user with an invalid home_dir must fail")
}
c := Connection{
User: u,
}
err = c.isSubDir("dir_rel_path")
if err == nil {
t.Errorf("tested path is not a home subdir")
}
}