Browse Source

dataprovider move db handle to provider struct

This is needed to support non SQL providers
Nicola Murino 6 years ago
parent
commit
cb87fe811a
7 changed files with 74 additions and 58 deletions
  1. 10 3
      README.md
  2. 2 6
      dataprovider/dataprovider.go
  3. 14 12
      dataprovider/mysql.go
  4. 14 12
      dataprovider/pgsql.go
  5. 13 13
      dataprovider/sqlcommon.go
  6. 13 11
      dataprovider/sqlite.go
  7. 8 1
      sftpd/internal_test.go

+ 10 - 3
README.md

@@ -40,7 +40,14 @@ $ go get -u github.com/drakkan/sftpgo
 
 Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
 
-Version info can be embedded populating the following variables at build time:
+SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time.
+On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from it's command prompt.
+
+The compiler is a build time only dependency, it is not not required at runtime.
+
+If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL and MySQL will work and you don't need a `C` compiler for building.
+
+Version info, such as git commit and build date, can be embedded setting the following string variables at build time:
 
 - `github.com/drakkan/sftpgo/utils.commit`
 - `github.com/drakkan/sftpgo/utils.date`
@@ -54,11 +61,11 @@ go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git descr
 and you will get a version that includes git commit and build date like this one:
 
 ```bash
-./sftpgo -v
+sftpgo -v
 SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z
 ```
 
-A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
+For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
 
 Alternately you can use distro packages:
 

+ 2 - 6
dataprovider/dataprovider.go

@@ -4,7 +4,6 @@
 package dataprovider
 
 import (
-	"database/sql"
 	"fmt"
 	"path/filepath"
 	"strings"
@@ -33,7 +32,6 @@ const (
 var (
 	// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
 	SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName}
-	dbHandle           *sql.DB
 	config             Config
 	provider           Provider
 	sqlPlaceholders    []string
@@ -124,13 +122,10 @@ func Initialize(cnf Config, basePath string) error {
 	config = cnf
 	sqlPlaceholders = getSQLPlaceholders()
 	if config.Driver == SQLiteDataProviderName {
-		provider = SQLiteProvider{}
 		return initializeSQLiteProvider(basePath)
 	} else if config.Driver == PGSSQLDataProviderName {
-		provider = PGSQLProvider{}
 		return initializePGSQLProvider()
 	} else if config.Driver == MySQLDataProviderName {
-		provider = MySQLProvider{}
 		return initializeMySQLProvider()
 	}
 	return fmt.Errorf("Unsupported data provider: %v", config.Driver)
@@ -226,7 +221,8 @@ func validateUser(user *User) error {
 			return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)}
 		}
 	}
-	if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) {
+	if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) &&
+		!strings.HasPrefix(user.Password, bcryptPwdPrefix) {
 		pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams)
 		if err != nil {
 			return err

+ 14 - 12
dataprovider/mysql.go

@@ -11,6 +11,7 @@ import (
 
 // MySQLProvider auth provider for MySQL/MariaDB database
 type MySQLProvider struct {
+	dbHandle *sql.DB
 }
 
 func initializeMySQLProvider() error {
@@ -22,13 +23,14 @@ func initializeMySQLProvider() error {
 	} else {
 		connectionString = config.ConnectionString
 	}
-	dbHandle, err = sql.Open("mysql", connectionString)
+	dbHandle, err := sql.Open("mysql", connectionString)
 	if err == nil {
 		numCPU := runtime.NumCPU()
 		logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
 		dbHandle.SetMaxIdleConns(numCPU)
 		dbHandle.SetMaxOpenConns(numCPU)
 		dbHandle.SetConnMaxLifetime(1800 * time.Second)
+		provider = MySQLProvider{dbHandle: dbHandle}
 	} else {
 		logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err)
 	}
@@ -36,24 +38,24 @@ func initializeMySQLProvider() error {
 }
 
 func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) {
-	return sqlCommonValidateUserAndPass(username, password)
+	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
 func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
-	return sqlCommonValidateUserAndPubKey(username, publicKey)
+	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 
 func (p MySQLProvider) getUserByID(ID int64) (User, error) {
-	return sqlCommonGetUserByID(ID)
+	return sqlCommonGetUserByID(ID, p.dbHandle)
 }
 
 func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
-	tx, err := dbHandle.Begin()
+	tx, err := p.dbHandle.Begin()
 	if err != nil {
 		logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
 		return err
 	}
-	err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
+	err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
 	if err == nil {
 		err = tx.Commit()
 	} else {
@@ -66,25 +68,25 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
 }
 
 func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
-	return sqlCommonGetUsedQuota(username)
+	return sqlCommonGetUsedQuota(username, p.dbHandle)
 }
 
 func (p MySQLProvider) userExists(username string) (User, error) {
-	return sqlCommonCheckUserExists(username)
+	return sqlCommonCheckUserExists(username, p.dbHandle)
 }
 
 func (p MySQLProvider) addUser(user User) error {
-	return sqlCommonAddUser(user)
+	return sqlCommonAddUser(user, p.dbHandle)
 }
 
 func (p MySQLProvider) updateUser(user User) error {
-	return sqlCommonUpdateUser(user)
+	return sqlCommonUpdateUser(user, p.dbHandle)
 }
 
 func (p MySQLProvider) deleteUser(user User) error {
-	return sqlCommonDeleteUser(user)
+	return sqlCommonDeleteUser(user, p.dbHandle)
 }
 
 func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
-	return sqlCommonGetUsers(limit, offset, order, username)
+	return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
 }

+ 14 - 12
dataprovider/pgsql.go

@@ -10,6 +10,7 @@ import (
 
 // PGSQLProvider auth provider for PostgreSQL database
 type PGSQLProvider struct {
+	dbHandle *sql.DB
 }
 
 func initializePGSQLProvider() error {
@@ -21,12 +22,13 @@ func initializePGSQLProvider() error {
 	} else {
 		connectionString = config.ConnectionString
 	}
-	dbHandle, err = sql.Open("postgres", connectionString)
+	dbHandle, err := sql.Open("postgres", connectionString)
 	if err == nil {
 		numCPU := runtime.NumCPU()
 		logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
 		dbHandle.SetMaxIdleConns(numCPU)
 		dbHandle.SetMaxOpenConns(numCPU)
+		provider = PGSQLProvider{dbHandle: dbHandle}
 	} else {
 		logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err)
 	}
@@ -34,24 +36,24 @@ func initializePGSQLProvider() error {
 }
 
 func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) {
-	return sqlCommonValidateUserAndPass(username, password)
+	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
 func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
-	return sqlCommonValidateUserAndPubKey(username, publicKey)
+	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 
 func (p PGSQLProvider) getUserByID(ID int64) (User, error) {
-	return sqlCommonGetUserByID(ID)
+	return sqlCommonGetUserByID(ID, p.dbHandle)
 }
 
 func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
-	tx, err := dbHandle.Begin()
+	tx, err := p.dbHandle.Begin()
 	if err != nil {
 		logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
 		return err
 	}
-	err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
+	err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
 	if err == nil {
 		err = tx.Commit()
 	} else {
@@ -64,25 +66,25 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
 }
 
 func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
-	return sqlCommonGetUsedQuota(username)
+	return sqlCommonGetUsedQuota(username, p.dbHandle)
 }
 
 func (p PGSQLProvider) userExists(username string) (User, error) {
-	return sqlCommonCheckUserExists(username)
+	return sqlCommonCheckUserExists(username, p.dbHandle)
 }
 
 func (p PGSQLProvider) addUser(user User) error {
-	return sqlCommonAddUser(user)
+	return sqlCommonAddUser(user, p.dbHandle)
 }
 
 func (p PGSQLProvider) updateUser(user User) error {
-	return sqlCommonUpdateUser(user)
+	return sqlCommonUpdateUser(user, p.dbHandle)
 }
 
 func (p PGSQLProvider) deleteUser(user User) error {
-	return sqlCommonDeleteUser(user)
+	return sqlCommonDeleteUser(user, p.dbHandle)
 }
 
 func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
-	return sqlCommonGetUsers(limit, offset, order, username)
+	return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
 }

+ 13 - 13
dataprovider/sqlcommon.go

@@ -16,7 +16,7 @@ import (
 	"github.com/drakkan/sftpgo/utils"
 )
 
-func getUserByUsername(username string) (User, error) {
+func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
 	var user User
 	q := getUserByUsernameQuery()
 	stmt, err := dbHandle.Prepare(q)
@@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) {
 	return getUserFromDbRow(row, nil)
 }
 
-func sqlCommonValidateUserAndPass(username string, password string) (User, error) {
+func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
 	var user User
 	if len(password) == 0 {
 		return user, errors.New("Credentials cannot be null or empty")
 	}
-	user, err := getUserByUsername(username)
+	user, err := getUserByUsername(username, dbHandle)
 	if err != nil {
 		logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
 	} else {
@@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error
 	return user, err
 }
 
-func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) {
+func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
 	var user User
 	if len(pubKey) == 0 {
 		return user, errors.New("Credentials cannot be null or empty")
 	}
-	user, err := getUserByUsername(username)
+	user, err := getUserByUsername(username, dbHandle)
 	if err != nil {
 		logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
 		return user, err
@@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error
 	return user, errors.New("Invalid credentials")
 }
 
-func sqlCommonGetUserByID(ID int64) (User, error) {
+func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
 	var user User
 	q := getUserByIDQuery()
 	stmt, err := dbHandle.Prepare(q)
@@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) {
 	return getUserFromDbRow(row, nil)
 }
 
-func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error {
+func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
 	q := getUpdateQuotaQuery(reset)
 	stmt, err := dbHandle.Prepare(q)
 	if err != nil {
@@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
 	return err
 }
 
-func sqlCommonGetUsedQuota(username string) (int, int64, error) {
+func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
 	q := getQuotaQuery()
 	stmt, err := dbHandle.Prepare(q)
 	if err != nil {
@@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) {
 	return usedFiles, usedSize, err
 }
 
-func sqlCommonCheckUserExists(username string) (User, error) {
+func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
 	var user User
 	q := getUserByUsernameQuery()
 	stmt, err := dbHandle.Prepare(q)
@@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) {
 	return getUserFromDbRow(row, nil)
 }
 
-func sqlCommonAddUser(user User) error {
+func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
 	err := validateUser(&user)
 	if err != nil {
 		return err
@@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error {
 	return err
 }
 
-func sqlCommonUpdateUser(user User) error {
+func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
 	err := validateUser(&user)
 	if err != nil {
 		return err
@@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error {
 	return err
 }
 
-func sqlCommonDeleteUser(user User) error {
+func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
 	q := getDeleteUserQuery()
 	stmt, err := dbHandle.Prepare(q)
 	if err != nil {
@@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error {
 	return err
 }
 
-func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) {
+func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
 	users := []User{}
 	q := getUsersQuery(order, username)
 	stmt, err := dbHandle.Prepare(q)

+ 13 - 11
dataprovider/sqlite.go

@@ -12,6 +12,7 @@ import (
 
 // SQLiteProvider auth provider for SQLite database
 type SQLiteProvider struct {
+	dbHandle *sql.DB
 }
 
 func initializeSQLiteProvider(basePath string) error {
@@ -36,10 +37,11 @@ func initializeSQLiteProvider(basePath string) error {
 	} else {
 		connectionString = config.ConnectionString
 	}
-	dbHandle, err = sql.Open("sqlite3", connectionString)
+	dbHandle, err := sql.Open("sqlite3", connectionString)
 	if err == nil {
 		logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString)
 		dbHandle.SetMaxOpenConns(1)
+		provider = SQLiteProvider{dbHandle: dbHandle}
 	} else {
 		logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err)
 	}
@@ -47,43 +49,43 @@ func initializeSQLiteProvider(basePath string) error {
 }
 
 func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) {
-	return sqlCommonValidateUserAndPass(username, password)
+	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
 func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
-	return sqlCommonValidateUserAndPubKey(username, publicKey)
+	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 
 func (p SQLiteProvider) getUserByID(ID int64) (User, error) {
-	return sqlCommonGetUserByID(ID)
+	return sqlCommonGetUserByID(ID, p.dbHandle)
 }
 
 func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
 	// we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block
 	// the database access since it will try to open a new connection
-	return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
+	return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
 }
 
 func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
-	return sqlCommonGetUsedQuota(username)
+	return sqlCommonGetUsedQuota(username, p.dbHandle)
 }
 
 func (p SQLiteProvider) userExists(username string) (User, error) {
-	return sqlCommonCheckUserExists(username)
+	return sqlCommonCheckUserExists(username, p.dbHandle)
 }
 
 func (p SQLiteProvider) addUser(user User) error {
-	return sqlCommonAddUser(user)
+	return sqlCommonAddUser(user, p.dbHandle)
 }
 
 func (p SQLiteProvider) updateUser(user User) error {
-	return sqlCommonUpdateUser(user)
+	return sqlCommonUpdateUser(user, p.dbHandle)
 }
 
 func (p SQLiteProvider) deleteUser(user User) error {
-	return sqlCommonDeleteUser(user)
+	return sqlCommonDeleteUser(user, p.dbHandle)
 }
 
 func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
-	return sqlCommonGetUsers(limit, offset, order, username)
+	return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
 }

+ 8 - 1
sftpd/internal_test.go

@@ -99,11 +99,18 @@ func TestUploadFiles(t *testing.T) {
 	uploadMode = oldUploadMode
 }
 
-func TestLoginWithInvalidHome(t *testing.T) {
+func TestWithInvalidHome(t *testing.T) {
 	u := dataprovider.User{}
 	u.HomeDir = "home_rel_path"
 	_, err := loginUser(u)
 	if err == nil {
 		t.Errorf("login a user with an invalid home_dir must fail")
 	}
+	c := Connection{
+		User: u,
+	}
+	err = c.isSubDir("dir_rel_path")
+	if err == nil {
+		t.Errorf("tested path is not a home subdir")
+	}
 }