sftpgo-mirror/dataprovider/pgsql.go
Nicola Murino dccc583b5d
add a dedicated struct to store encrypted credentials
also gcs credentials are now encrypted, both on disk and inside the
provider.

Data provider is automatically migrated and load data will accept
old format too but you should upgrade to the new format to avoid future
issues
2020-11-22 21:53:04 +01:00

272 lines
9.8 KiB
Go

// +build !nopgsql
package dataprovider
import (
"context"
"database/sql"
"fmt"
"strings"
// we import lib/pq here to be able to disable PostgreSQL support using a build tag
_ "github.com/lib/pq"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/version"
"github.com/drakkan/sftpgo/vfs"
)
const (
pgsqlUsersTableSQL = `CREATE TABLE "{{users}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE,
"password" varchar(255) NULL, "public_keys" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL,
"gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
"permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL,
"last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL,
"expiration_date" bigint NOT NULL, "last_login" bigint NOT NULL, "status" integer NOT NULL, "filters" text NULL,
"filesystem" text NULL);`
pgsqlSchemaTableSQL = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);`
pgsqlV2SQL = `ALTER TABLE "{{users}}" ADD COLUMN "virtual_folders" text NULL;`
pgsqlV3SQL = `ALTER TABLE "{{users}}" ALTER COLUMN "password" TYPE text USING "password"::text;`
pgsqlV4SQL = `CREATE TABLE "{{folders}}" ("id" serial NOT NULL PRIMARY KEY, "path" varchar(512) NOT NULL UNIQUE, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL);
ALTER TABLE "{{users}}" ALTER COLUMN "home_dir" TYPE varchar(512) USING "home_dir"::varchar(512);
ALTER TABLE "{{users}}" DROP COLUMN "virtual_folders" CASCADE;
CREATE TABLE "{{folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "virtual_path" varchar(512) NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL, "user_id" integer NOT NULL);
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "unique_mapping" UNIQUE ("user_id", "folder_id");
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "folders_mapping_folder_id_fk_folders_id" FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "folders_mapping_user_id_fk_users_id" FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
CREATE INDEX "folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
`
)
// PGSQLProvider auth provider for PostgreSQL database
type PGSQLProvider struct {
dbHandle *sql.DB
}
func init() {
version.AddFeature("+pgsql")
}
func initializePGSQLProvider() error {
var err error
logSender = fmt.Sprintf("dataprovider_%v", PGSQLDataProviderName)
dbHandle, err := sql.Open("postgres", getPGSQLConnectionString(false))
if err == nil {
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
getPGSQLConnectionString(true), config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
provider = PGSQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelWarn, "error creating postgres database handler, connection string: %#v, error: %v",
getPGSQLConnectionString(true), err)
}
return err
}
func getPGSQLConnectionString(redactedPwd bool) string {
var connectionString string
if len(config.ConnectionString) == 0 {
password := config.Password
if redactedPwd {
password = "[redacted]"
}
connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10",
config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
} else {
connectionString = config.ConnectionString
}
return connectionString
}
func (p PGSQLProvider) checkAvailability() error {
return sqlCommonCheckAvailability(p.dbHandle)
}
func (p PGSQLProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle)
}
func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
}
func (p PGSQLProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID, p.dbHandle)
}
func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username, p.dbHandle)
}
func (p PGSQLProvider) updateLastLogin(username string) error {
return sqlCommonUpdateLastLogin(username, p.dbHandle)
}
func (p PGSQLProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username, p.dbHandle)
}
func (p PGSQLProvider) addUser(user User) error {
return sqlCommonAddUser(user, p.dbHandle)
}
func (p PGSQLProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p PGSQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p PGSQLProvider) dumpUsers() ([]User, error) {
return sqlCommonDumpUsers(p.dbHandle)
}
func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
}
func (p PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
return sqlCommonDumpFolders(p.dbHandle)
}
func (p PGSQLProvider) getFolders(limit, offset int, order, folderPath string) ([]vfs.BaseVirtualFolder, error) {
return sqlCommonGetFolders(limit, offset, order, folderPath, p.dbHandle)
}
func (p PGSQLProvider) getFolderByPath(mappedPath string) (vfs.BaseVirtualFolder, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return sqlCommonCheckFolderExists(ctx, mappedPath, p.dbHandle)
}
func (p PGSQLProvider) addFolder(folder vfs.BaseVirtualFolder) error {
return sqlCommonAddFolder(folder, p.dbHandle)
}
func (p PGSQLProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
return sqlCommonDeleteFolder(folder, p.dbHandle)
}
func (p PGSQLProvider) updateFolderQuota(mappedPath string, filesAdd int, sizeAdd int64, reset bool) error {
return sqlCommonUpdateFolderQuota(mappedPath, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p PGSQLProvider) getUsedFolderQuota(mappedPath string) (int, int64, error) {
return sqlCommonGetFolderUsedQuota(mappedPath, p.dbHandle)
}
func (p PGSQLProvider) close() error {
return p.dbHandle.Close()
}
func (p PGSQLProvider) reloadConfig() error {
return nil
}
// initializeDatabase creates the initial database structure
func (p PGSQLProvider) initializeDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false)
if err == nil && dbVersion.Version > 0 {
return ErrNoInitRequired
}
sqlUsers := strings.Replace(pgsqlUsersTableSQL, "{{users}}", sqlTableUsers, 1)
tx, err := p.dbHandle.Begin()
if err != nil {
return err
}
_, err = tx.Exec(sqlUsers)
if err != nil {
sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(strings.Replace(pgsqlSchemaTableSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil {
sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(strings.Replace(initialDBVersionSQL, "{{schema_version}}", sqlTableSchemaVersion, 1))
if err != nil {
sqlCommonRollbackTransaction(tx)
return err
}
return tx.Commit()
}
func (p PGSQLProvider) migrateDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
if err != nil {
return err
}
if dbVersion.Version == sqlDatabaseVersion {
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", dbVersion.Version)
return ErrNoInitRequired
}
switch dbVersion.Version {
case 1:
return updatePGSQLDatabaseFromV1(p.dbHandle)
case 2:
return updatePGSQLDatabaseFromV2(p.dbHandle)
case 3:
return updatePGSQLDatabaseFromV3(p.dbHandle)
case 4:
return updatePGSQLDatabaseFromV4(p.dbHandle)
default:
return fmt.Errorf("Database version not handled: %v", dbVersion.Version)
}
}
func updatePGSQLDatabaseFromV1(dbHandle *sql.DB) error {
err := updatePGSQLDatabaseFrom1To2(dbHandle)
if err != nil {
return err
}
return updatePGSQLDatabaseFromV2(dbHandle)
}
func updatePGSQLDatabaseFromV2(dbHandle *sql.DB) error {
err := updatePGSQLDatabaseFrom2To3(dbHandle)
if err != nil {
return err
}
return updatePGSQLDatabaseFromV3(dbHandle)
}
func updatePGSQLDatabaseFromV3(dbHandle *sql.DB) error {
err := updatePGSQLDatabaseFrom3To4(dbHandle)
if err != nil {
return err
}
return updatePGSQLDatabaseFromV4(dbHandle)
}
func updatePGSQLDatabaseFromV4(dbHandle *sql.DB) error {
return updatePGSQLDatabaseFrom4To5(dbHandle)
}
func updatePGSQLDatabaseFrom1To2(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 1 -> 2")
providerLog(logger.LevelInfo, "updating database version: 1 -> 2")
sql := strings.Replace(pgsqlV2SQL, "{{users}}", sqlTableUsers, 1)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 2)
}
func updatePGSQLDatabaseFrom2To3(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 2 -> 3")
providerLog(logger.LevelInfo, "updating database version: 2 -> 3")
sql := strings.Replace(pgsqlV3SQL, "{{users}}", sqlTableUsers, 1)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 3)
}
func updatePGSQLDatabaseFrom3To4(dbHandle *sql.DB) error {
return sqlCommonUpdateDatabaseFrom3To4(pgsqlV4SQL, dbHandle)
}
func updatePGSQLDatabaseFrom4To5(dbHandle *sql.DB) error {
return sqlCommonUpdateDatabaseFrom4To5(dbHandle)
}