diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index a2cb967c..f4028487 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -419,17 +419,18 @@ func (p BoltProvider) migrateDatabase() error { providerLog(logger.LevelDebug, "bolt database is updated, current version: %v", dbVersion.Version) return nil } - if dbVersion.Version == 1 { + switch dbVersion.Version { + case 1: err = updateDatabaseFrom1To2(p.dbHandle) if err != nil { return err } return updateDatabaseFrom2To3(p.dbHandle) - } else if dbVersion.Version == 2 { + case 2: return updateDatabaseFrom2To3(p.dbHandle) + default: + return fmt.Errorf("Database version not handled: %v", dbVersion.Version) } - - return nil } // itob returns an 8-byte big endian representation of v. diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 7a8a9442..23e27084 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -20,6 +20,7 @@ const ( "`filesystem` longtext DEFAULT NULL);" mysqlSchemaTableSQL = "CREATE TABLE `schema_version` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" mysqlUsersV2SQL = "ALTER TABLE `{{users}}` ADD COLUMN `virtual_folders` longtext NULL;" + mysqlUsersV3SQL = "ALTER TABLE `{{users}}` MODIFY `password` longtext NULL;" ) // MySQLProvider auth provider for MySQL/MariaDB database @@ -152,15 +153,33 @@ func (p MySQLProvider) migrateDatabase() error { providerLog(logger.LevelDebug, "sql database is updated, current version: %v", dbVersion.Version) return nil } - if dbVersion.Version == 1 { - return updateMySQLDatabaseFrom1To2(p.dbHandle) + switch dbVersion.Version { + case 1: + err = updateMySQLDatabaseFrom1To2(p.dbHandle) + if err != nil { + return err + } + return updateMySQLDatabaseFrom2To3(p.dbHandle) + case 2: + return updateMySQLDatabaseFrom2To3(p.dbHandle) + default: + return fmt.Errorf("Database version not handled: %v", dbVersion.Version) } - return nil } func updateMySQLDatabaseFrom1To2(dbHandle *sql.DB) error { providerLog(logger.LevelInfo, "updating database version: 1 -> 2") sql := strings.Replace(mysqlUsersV2SQL, "{{users}}", config.UsersTable, 1) + return updateMySQLDatabase(dbHandle, sql, 2) +} + +func updateMySQLDatabaseFrom2To3(dbHandle *sql.DB) error { + providerLog(logger.LevelInfo, "updating database version: 2 -> 3") + sql := strings.Replace(mysqlUsersV3SQL, "{{users}}", config.UsersTable, 1) + return updateMySQLDatabase(dbHandle, sql, 3) +} + +func updateMySQLDatabase(dbHandle *sql.DB, sql string, newVersion int) error { tx, err := dbHandle.Begin() if err != nil { return err @@ -170,7 +189,7 @@ func updateMySQLDatabaseFrom1To2(dbHandle *sql.DB) error { tx.Rollback() return err } - err = sqlCommonUpdateDatabaseVersionWithTX(tx, 2) + err = sqlCommonUpdateDatabaseVersionWithTX(tx, newVersion) if err != nil { tx.Rollback() return err diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 32381442..de8b5c77 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -18,6 +18,7 @@ const ( "filesystem" text NULL);` pgsqlSchemaTableSQL = `CREATE TABLE "schema_version" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);` pgsqlUsersV2SQL = `ALTER TABLE "{{users}}" ADD COLUMN "virtual_folders" text NULL;` + pgsqlUsersV3SQL = `ALTER TABLE "{{users}}" ALTER COLUMN "password" TYPE text USING "password"::text;` ) // PGSQLProvider auth provider for PostgreSQL database @@ -150,15 +151,33 @@ func (p PGSQLProvider) migrateDatabase() error { providerLog(logger.LevelDebug, "sql database is updated, current version: %v", dbVersion.Version) return nil } - if dbVersion.Version == 1 { - return updatePGSQLDatabaseFrom1To2(p.dbHandle) + switch dbVersion.Version { + case 1: + err = updatePGSQLDatabaseFrom1To2(p.dbHandle) + if err != nil { + return err + } + return updatePGSQLDatabaseFrom2To3(p.dbHandle) + case 2: + return updatePGSQLDatabaseFrom2To3(p.dbHandle) + default: + return fmt.Errorf("Database version not handled: %v", dbVersion.Version) } - return nil } func updatePGSQLDatabaseFrom1To2(dbHandle *sql.DB) error { providerLog(logger.LevelInfo, "updating database version: 1 -> 2") sql := strings.Replace(pgsqlUsersV2SQL, "{{users}}", config.UsersTable, 1) + return updatePGSQLDatabase(dbHandle, sql, 2) +} + +func updatePGSQLDatabaseFrom2To3(dbHandle *sql.DB) error { + providerLog(logger.LevelInfo, "updating database version: 2 -> 3") + sql := strings.Replace(pgsqlUsersV3SQL, "{{users}}", config.UsersTable, 1) + return updatePGSQLDatabase(dbHandle, sql, 3) +} + +func updatePGSQLDatabase(dbHandle *sql.DB, sql string, newVersion int) error { tx, err := dbHandle.Begin() if err != nil { return err @@ -168,7 +187,7 @@ func updatePGSQLDatabaseFrom1To2(dbHandle *sql.DB) error { tx.Rollback() return err } - err = sqlCommonUpdateDatabaseVersionWithTX(tx, 2) + err = sqlCommonUpdateDatabaseVersionWithTX(tx, newVersion) if err != nil { tx.Rollback() return err diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 686ee17a..509cf908 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -13,7 +13,7 @@ import ( ) const ( - sqlDatabaseVersion = 2 + sqlDatabaseVersion = 3 initialDBVersionSQL = "INSERT INTO schema_version (version) VALUES (1);" ) diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 920a699f..16f86904 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -20,6 +20,20 @@ NOT NULL UNIQUE, "password" varchar(255) NULL, "public_keys" text NULL, "home_di "filesystem" text NULL);` sqliteSchemaTableSQL = `CREATE TABLE "schema_version" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);` sqliteUsersV2SQL = `ALTER TABLE "{{users}}" ADD COLUMN "virtual_folders" text NULL;` + sqliteUsersV3SQL = `CREATE TABLE "new__users" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE, + "password" text NULL, "public_keys" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL, +"gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, +"permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, +"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "expiration_date" bigint NOT NULL, "last_login" bigint NOT NULL, +"status" integer NOT NULL, "filters" text NULL, "filesystem" text NULL, "virtual_folders" text NULL); +INSERT INTO "new__users" ("id", "username", "public_keys", "home_dir", "uid", "gid", "max_sessions", "quota_size", "quota_files", +"permissions", "used_quota_size", "used_quota_files", "last_quota_update", "upload_bandwidth", "download_bandwidth", "expiration_date", +"last_login", "status", "filters", "filesystem", "virtual_folders", "password") SELECT "id", "username", "public_keys", "home_dir", +"uid", "gid", "max_sessions", "quota_size", "quota_files", "permissions", "used_quota_size", "used_quota_files", "last_quota_update", +"upload_bandwidth", "download_bandwidth", "expiration_date", "last_login", "status", "filters", "filesystem", "virtual_folders", +"password" FROM "{{users}}"; +DROP TABLE "{{users}}"; +ALTER TABLE "new__users" RENAME TO "{{users}}";` ) // SQLiteProvider auth provider for SQLite database @@ -132,10 +146,18 @@ func (p SQLiteProvider) migrateDatabase() error { providerLog(logger.LevelDebug, "sql database is updated, current version: %v", dbVersion.Version) return nil } - if dbVersion.Version == 1 { - return updateSQLiteDatabaseFrom1To2(p.dbHandle) + switch dbVersion.Version { + case 1: + err = updateSQLiteDatabaseFrom1To2(p.dbHandle) + if err != nil { + return err + } + return updateSQLiteDatabaseFrom2To3(p.dbHandle) + case 2: + return updateSQLiteDatabaseFrom2To3(p.dbHandle) + default: + return fmt.Errorf("Database version not handled: %v", dbVersion.Version) } - return nil } func updateSQLiteDatabaseFrom1To2(dbHandle *sql.DB) error { @@ -147,3 +169,13 @@ func updateSQLiteDatabaseFrom1To2(dbHandle *sql.DB) error { } return sqlCommonUpdateDatabaseVersion(dbHandle, 2) } + +func updateSQLiteDatabaseFrom2To3(dbHandle *sql.DB) error { + providerLog(logger.LevelInfo, "updating database version: 2 -> 3") + sql := strings.ReplaceAll(sqliteUsersV3SQL, "{{users}}", config.UsersTable) + _, err := dbHandle.Exec(sql) + if err != nil { + return err + } + return sqlCommonUpdateDatabaseVersion(dbHandle, 3) +}