From cb87fe811a87a0f4a96df49c11180a235da4d2eb Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sun, 11 Aug 2019 14:53:37 +0200 Subject: [PATCH] dataprovider move db handle to provider struct This is needed to support non SQL providers --- README.md | 13 ++++++++++--- dataprovider/dataprovider.go | 8 ++------ dataprovider/mysql.go | 26 ++++++++++++++------------ dataprovider/pgsql.go | 26 ++++++++++++++------------ dataprovider/sqlcommon.go | 26 +++++++++++++------------- dataprovider/sqlite.go | 24 +++++++++++++----------- sftpd/internal_test.go | 9 ++++++++- 7 files changed, 74 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 547c8645..afa90a77 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,14 @@ $ go get -u github.com/drakkan/sftpgo Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. -Version info can be embedded populating the following variables at build time: +SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time. +On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from it's command prompt. + +The compiler is a build time only dependency, it is not not required at runtime. + +If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL and MySQL will work and you don't need a `C` compiler for building. + +Version info, such as git commit and build date, can be embedded setting the following string variables at build time: - `github.com/drakkan/sftpgo/utils.commit` - `github.com/drakkan/sftpgo/utils.date` @@ -54,11 +61,11 @@ go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git descr and you will get a version that includes git commit and build date like this one: ```bash -./sftpgo -v +sftpgo -v SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z ``` -A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree. +For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree. Alternately you can use distro packages: diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 0a92d95e..bc0b2e65 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -4,7 +4,6 @@ package dataprovider import ( - "database/sql" "fmt" "path/filepath" "strings" @@ -33,7 +32,6 @@ const ( var ( // SupportedProviders data provider configured in the sftpgo.conf file must match of these strings SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName} - dbHandle *sql.DB config Config provider Provider sqlPlaceholders []string @@ -124,13 +122,10 @@ func Initialize(cnf Config, basePath string) error { config = cnf sqlPlaceholders = getSQLPlaceholders() if config.Driver == SQLiteDataProviderName { - provider = SQLiteProvider{} return initializeSQLiteProvider(basePath) } else if config.Driver == PGSSQLDataProviderName { - provider = PGSQLProvider{} return initializePGSQLProvider() } else if config.Driver == MySQLDataProviderName { - provider = MySQLProvider{} return initializeMySQLProvider() } return fmt.Errorf("Unsupported data provider: %v", config.Driver) @@ -226,7 +221,8 @@ func validateUser(user *User) error { return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)} } } - if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) { + if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) && + !strings.HasPrefix(user.Password, bcryptPwdPrefix) { pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams) if err != nil { return err diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 7fc3fbfa..92ba0992 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -11,6 +11,7 @@ import ( // MySQLProvider auth provider for MySQL/MariaDB database type MySQLProvider struct { + dbHandle *sql.DB } func initializeMySQLProvider() error { @@ -22,13 +23,14 @@ func initializeMySQLProvider() error { } else { connectionString = config.ConnectionString } - dbHandle, err = sql.Open("mysql", connectionString) + dbHandle, err := sql.Open("mysql", connectionString) if err == nil { numCPU := runtime.NumCPU() logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxOpenConns(numCPU) dbHandle.SetConnMaxLifetime(1800 * time.Second) + provider = MySQLProvider{dbHandle: dbHandle} } else { logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err) } @@ -36,24 +38,24 @@ func initializeMySQLProvider() error { } func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password) + return sqlCommonValidateUserAndPass(username, password, p.dbHandle) } func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey) + return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } func (p MySQLProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID) + return sqlCommonGetUserByID(ID, p.dbHandle) } func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - tx, err := dbHandle.Begin() + tx, err := p.dbHandle.Begin() if err != nil { logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) return err } - err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) + err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) if err == nil { err = tx.Commit() } else { @@ -66,25 +68,25 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, } func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username) + return sqlCommonGetUsedQuota(username, p.dbHandle) } func (p MySQLProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username) + return sqlCommonCheckUserExists(username, p.dbHandle) } func (p MySQLProvider) addUser(user User) error { - return sqlCommonAddUser(user) + return sqlCommonAddUser(user, p.dbHandle) } func (p MySQLProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user) + return sqlCommonUpdateUser(user, p.dbHandle) } func (p MySQLProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user) + return sqlCommonDeleteUser(user, p.dbHandle) } func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username) + return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) } diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index fa839fa4..20a01125 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -10,6 +10,7 @@ import ( // PGSQLProvider auth provider for PostgreSQL database type PGSQLProvider struct { + dbHandle *sql.DB } func initializePGSQLProvider() error { @@ -21,12 +22,13 @@ func initializePGSQLProvider() error { } else { connectionString = config.ConnectionString } - dbHandle, err = sql.Open("postgres", connectionString) + dbHandle, err := sql.Open("postgres", connectionString) if err == nil { numCPU := runtime.NumCPU() logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxOpenConns(numCPU) + provider = PGSQLProvider{dbHandle: dbHandle} } else { logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err) } @@ -34,24 +36,24 @@ func initializePGSQLProvider() error { } func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password) + return sqlCommonValidateUserAndPass(username, password, p.dbHandle) } func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey) + return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } func (p PGSQLProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID) + return sqlCommonGetUserByID(ID, p.dbHandle) } func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - tx, err := dbHandle.Begin() + tx, err := p.dbHandle.Begin() if err != nil { logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) return err } - err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) + err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) if err == nil { err = tx.Commit() } else { @@ -64,25 +66,25 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, } func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username) + return sqlCommonGetUsedQuota(username, p.dbHandle) } func (p PGSQLProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username) + return sqlCommonCheckUserExists(username, p.dbHandle) } func (p PGSQLProvider) addUser(user User) error { - return sqlCommonAddUser(user) + return sqlCommonAddUser(user, p.dbHandle) } func (p PGSQLProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user) + return sqlCommonUpdateUser(user, p.dbHandle) } func (p PGSQLProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user) + return sqlCommonDeleteUser(user, p.dbHandle) } func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username) + return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) } diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 31bf88f1..5ee10697 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -16,7 +16,7 @@ import ( "github.com/drakkan/sftpgo/utils" ) -func getUserByUsername(username string) (User, error) { +func getUserByUsername(username string, dbHandle *sql.DB) (User, error) { var user User q := getUserByUsernameQuery() stmt, err := dbHandle.Prepare(q) @@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) { return getUserFromDbRow(row, nil) } -func sqlCommonValidateUserAndPass(username string, password string) (User, error) { +func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) { var user User if len(password) == 0 { return user, errors.New("Credentials cannot be null or empty") } - user, err := getUserByUsername(username) + user, err := getUserByUsername(username, dbHandle) if err != nil { logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) } else { @@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error return user, err } -func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) { +func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) { var user User if len(pubKey) == 0 { return user, errors.New("Credentials cannot be null or empty") } - user, err := getUserByUsername(username) + user, err := getUserByUsername(username, dbHandle) if err != nil { logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) return user, err @@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error return user, errors.New("Invalid credentials") } -func sqlCommonGetUserByID(ID int64) (User, error) { +func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) { var user User q := getUserByIDQuery() stmt, err := dbHandle.Prepare(q) @@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) { return getUserFromDbRow(row, nil) } -func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error { +func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { q := getUpdateQuotaQuery(reset) stmt, err := dbHandle.Prepare(q) if err != nil { @@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo return err } -func sqlCommonGetUsedQuota(username string) (int, int64, error) { +func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) { q := getQuotaQuery() stmt, err := dbHandle.Prepare(q) if err != nil { @@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) { return usedFiles, usedSize, err } -func sqlCommonCheckUserExists(username string) (User, error) { +func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) { var user User q := getUserByUsernameQuery() stmt, err := dbHandle.Prepare(q) @@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) { return getUserFromDbRow(row, nil) } -func sqlCommonAddUser(user User) error { +func sqlCommonAddUser(user User, dbHandle *sql.DB) error { err := validateUser(&user) if err != nil { return err @@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error { return err } -func sqlCommonUpdateUser(user User) error { +func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error { err := validateUser(&user) if err != nil { return err @@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error { return err } -func sqlCommonDeleteUser(user User) error { +func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error { q := getDeleteUserQuery() stmt, err := dbHandle.Prepare(q) if err != nil { @@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error { return err } -func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) { +func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) { users := []User{} q := getUsersQuery(order, username) stmt, err := dbHandle.Prepare(q) diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index c6cddc66..b032b59d 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -12,6 +12,7 @@ import ( // SQLiteProvider auth provider for SQLite database type SQLiteProvider struct { + dbHandle *sql.DB } func initializeSQLiteProvider(basePath string) error { @@ -36,10 +37,11 @@ func initializeSQLiteProvider(basePath string) error { } else { connectionString = config.ConnectionString } - dbHandle, err = sql.Open("sqlite3", connectionString) + dbHandle, err := sql.Open("sqlite3", connectionString) if err == nil { logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString) dbHandle.SetMaxOpenConns(1) + provider = SQLiteProvider{dbHandle: dbHandle} } else { logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err) } @@ -47,43 +49,43 @@ func initializeSQLiteProvider(basePath string) error { } func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password) + return sqlCommonValidateUserAndPass(username, password, p.dbHandle) } func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey) + return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) } func (p SQLiteProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID) + return sqlCommonGetUserByID(ID, p.dbHandle) } func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { // we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block // the database access since it will try to open a new connection - return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) + return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) } func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username) + return sqlCommonGetUsedQuota(username, p.dbHandle) } func (p SQLiteProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username) + return sqlCommonCheckUserExists(username, p.dbHandle) } func (p SQLiteProvider) addUser(user User) error { - return sqlCommonAddUser(user) + return sqlCommonAddUser(user, p.dbHandle) } func (p SQLiteProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user) + return sqlCommonUpdateUser(user, p.dbHandle) } func (p SQLiteProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user) + return sqlCommonDeleteUser(user, p.dbHandle) } func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username) + return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) } diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 0620edb0..47f2b019 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -99,11 +99,18 @@ func TestUploadFiles(t *testing.T) { uploadMode = oldUploadMode } -func TestLoginWithInvalidHome(t *testing.T) { +func TestWithInvalidHome(t *testing.T) { u := dataprovider.User{} u.HomeDir = "home_rel_path" _, err := loginUser(u) if err == nil { t.Errorf("login a user with an invalid home_dir must fail") } + c := Connection{ + User: u, + } + err = c.isSubDir("dir_rel_path") + if err == nil { + t.Errorf("tested path is not a home subdir") + } }