From e29a3efd39a1cc87c9d2f012d7a62410da8be47a Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Mon, 15 Nov 2021 18:40:31 +0100 Subject: [PATCH] add resetprovider sub-command Fixes #608 --- README.md | 8 +++- cmd/resetprovider.go | 75 ++++++++++++++++++++++++++++++++++++ common/common.go | 4 +- dataprovider/bolt.go | 12 ++++++ dataprovider/dataprovider.go | 17 ++++++++ dataprovider/memory.go | 4 ++ dataprovider/mysql.go | 21 ++++++++++ dataprovider/pgsql.go | 22 +++++++++++ dataprovider/sqlcommon.go | 23 +++++++---- dataprovider/sqlite.go | 22 +++++++++++ 10 files changed, 199 insertions(+), 9 deletions(-) create mode 100644 cmd/resetprovider.go diff --git a/README.md b/README.md index 97650141..7bfb45ff 100644 --- a/README.md +++ b/README.md @@ -135,13 +135,19 @@ sftpgo initprovider --help You can disable automatic data provider checks/updates at startup by setting the `update_mode` configuration key to `1`. +You can also reset your provider by using the `resetprovider` sub-command. Take a look at the CLI usage for more details: + +```bash +sftpgo resetprovider --help +``` + ## Create the first admin To start using SFTPGo you need to create an admin user, you can do it in several ways: - by using the web admin interface. The default URL is [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin) - by loading initial data -- by enabling `create_default_admin` in your configuration file. In this case the credentials are `admin`/`password` +- by enabling `create_default_admin` in your configuration file and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD` ## Upgrading diff --git a/cmd/resetprovider.go b/cmd/resetprovider.go new file mode 100644 index 00000000..437db6e7 --- /dev/null +++ b/cmd/resetprovider.go @@ -0,0 +1,75 @@ +package cmd + +import ( + "bufio" + "os" + "strings" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/drakkan/sftpgo/v2/config" + "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/logger" + "github.com/drakkan/sftpgo/v2/util" +) + +var ( + resetProviderForce bool + resetProviderCmd = &cobra.Command{ + Use: "resetprovider", + Short: "Reset the configured provider, any data will be lost", + Long: `This command reads the data provider connection details from the specified +configuration file and resets the provider by deleting all data and schemas. +This command is not supported for the memory provider. + +Please take a look at the usage below to customize the options.`, + Run: func(cmd *cobra.Command, args []string) { + logger.DisableLogger() + logger.EnableConsoleLogger(zerolog.DebugLevel) + configDir = util.CleanDirInput(configDir) + err := config.LoadConfig(configDir, configFile) + if err != nil { + logger.WarnToConsole("Unable to initialize data provider, config load error: %v", err) + os.Exit(1) + } + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("unable to initialize KMS: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + if !resetProviderForce { + logger.WarnToConsole("You are about to delete all the SFTPGo data for provider %#v, config file: %#v", + providerConf.Driver, viper.ConfigFileUsed()) + logger.WarnToConsole("Are you sure? (Y/n)") + reader := bufio.NewReader(os.Stdin) + answer, err := reader.ReadString('\n') + if err != nil { + logger.ErrorToConsole("unable to read your answer: %v", err) + os.Exit(1) + } + if strings.ToUpper(strings.TrimSpace(answer)) != "Y" { + logger.InfoToConsole("command aborted") + os.Exit(1) + } + } + logger.InfoToConsole("Resetting provider: %#v, config file: %#v", providerConf.Driver, viper.ConfigFileUsed()) + err = dataprovider.ResetDatabase(providerConf, configDir) + if err != nil { + logger.WarnToConsole("Error resetting provider: %v", err) + os.Exit(1) + } + logger.InfoToConsole("Tha data provider was successfully reset") + }, + } +) + +func init() { + addConfigFlags(resetProviderCmd) + resetProviderCmd.Flags().BoolVar(&resetProviderForce, "force", false, `reset the provider without asking for confirmation`) + + rootCmd.AddCommand(resetProviderCmd) +} diff --git a/common/common.go b/common/common.go index c1bb1cab..0e929675 100644 --- a/common/common.go +++ b/common/common.go @@ -683,8 +683,10 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error { for idx, conn := range conns.connections { if conn.GetID() == c.GetID() { - conn = nil + err := conn.CloseFS() conns.connections[idx] = c + logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err) + conn = nil return nil } } diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index 31f1cd17..34673190 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -1418,6 +1418,18 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { } } +func (p *BoltProvider) resetDatabase() error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + for _, bucketName := range boltBuckets { + err := tx.DeleteBucket(bucketName) + if err != nil && !errors.Is(err, bolt.ErrBucketNotFound) { + return fmt.Errorf("unable to remove bucket %v: %w", bucketName, err) + } + } + return nil + }) +} + func joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) { var user User err := json.Unmarshal(u, &user) diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index dac58aaa..f1f5f496 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -457,6 +457,7 @@ type Provider interface { initializeDatabase() error migrateDatabase() error revertDatabase(targetVersion int) error + resetDatabase() error } // SetTempPath sets the path for temporary files @@ -653,6 +654,22 @@ func RevertDatabase(cnf Config, basePath string, targetVersion int) error { return provider.revertDatabase(targetVersion) } +// ResetDatabase restores schema and/or data to a previous version +func ResetDatabase(cnf Config, basePath string) error { + config = cnf + + if filepath.IsAbs(config.CredentialsPath) { + credentialsDirPath = config.CredentialsPath + } else { + credentialsDirPath = filepath.Join(basePath, config.CredentialsPath) + } + + if err := createProvider(basePath); err != nil { + return err + } + return provider.resetDatabase() +} + // CheckAdminAndPass validates the given admin and password connecting from ip func CheckAdminAndPass(username, password, ip string) (Admin, error) { return provider.validateAdminAndPass(username, password, ip) diff --git a/dataprovider/memory.go b/dataprovider/memory.go index 7b0e38e1..bcf31128 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -1468,3 +1468,7 @@ func (p *MemoryProvider) migrateDatabase() error { func (p *MemoryProvider) revertDatabase(targetVersion int) error { return errors.New("memory provider does not store data, revert not possible") } + +func (p *MemoryProvider) resetDatabase() error { + return errors.New("memory provider does not store data, reset not possible") +} diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index d944403c..04e57fff 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -21,6 +21,13 @@ import ( ) const ( + mysqlResetSQL = "DROP TABLE IF EXISTS `{{api_keys}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{folders_mapping}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{admins}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{folders}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{shares}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{users}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;" mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " + @@ -318,6 +325,9 @@ func (p *MySQLProvider) initializeDatabase() error { if err == nil && dbVersion.Version > 0 { return ErrNoInitRequired } + if errors.Is(err, sql.ErrNoRows) { + return errSchemaVersionEmpty + } initialSQL := strings.ReplaceAll(mysqlInitialSQL, "{{schema_version}}", sqlTableSchemaVersion) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) @@ -387,6 +397,17 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { } } +func (p *MySQLProvider) resetDatabase() error { + sql := strings.ReplaceAll(mysqlResetSQL, "{{schema_version}}", sqlTableSchemaVersion) + sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins) + sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders) + sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping) + sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys) + sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0) +} + func updateMySQLDatabaseFromV10(dbHandle *sql.DB) error { if err := updateMySQLDatabaseFrom10To11(dbHandle); err != nil { return err diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index ea762648..6ae679fb 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -21,6 +21,14 @@ import ( ) const ( + pgsqlResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}" CASCADE; +DROP TABLE IF EXISTS "{{folders_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{admins}}" CASCADE; +DROP TABLE IF EXISTS "{{folders}}" CASCADE; +DROP TABLE IF EXISTS "{{shares}}" CASCADE; +DROP TABLE IF EXISTS "{{users}}" CASCADE; +DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; +` pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL); CREATE TABLE "{{admins}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, @@ -332,6 +340,9 @@ func (p *PGSQLProvider) initializeDatabase() error { if err == nil && dbVersion.Version > 0 { return ErrNoInitRequired } + if errors.Is(err, sql.ErrNoRows) { + return errSchemaVersionEmpty + } initialSQL := strings.ReplaceAll(pgsqlInitial, "{{schema_version}}", sqlTableSchemaVersion) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) @@ -407,6 +418,17 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { } } +func (p *PGSQLProvider) resetDatabase() error { + sql := strings.ReplaceAll(pgsqlResetSQL, "{{schema_version}}", sqlTableSchemaVersion) + sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins) + sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders) + sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping) + sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys) + sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) +} + func updatePGSQLDatabaseFromV10(dbHandle *sql.DB) error { if err := updatePGSQLDatabaseFrom10To11(dbHandle); err != nil { return err diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 84bef721..918e6311 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -24,7 +24,10 @@ const ( longSQLQueryTimeout = 60 * time.Second ) -var errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user") +var ( + errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user") + errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command") +) type sqlQuerier interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) @@ -946,7 +949,7 @@ func getShareFromDbRow(row sqlScanner) (Share, error) { &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens, &share.UsedTokens, &allowFrom) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return share, util.NewRecordNotFoundError(err.Error()) } return share, err @@ -986,7 +989,7 @@ func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) { &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return apiKey, util.NewRecordNotFoundError(err.Error()) } return apiKey, err @@ -1013,7 +1016,7 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) { &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return admin, util.NewRecordNotFoundError(err.Error()) } return admin, err @@ -1063,7 +1066,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) { &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return user, util.NewRecordNotFoundError(err.Error()) } return user, err @@ -1143,8 +1146,11 @@ func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) ( var mappedPath, description, fsConfig sql.NullString err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig) - if err == sql.ErrNoRows { - return folder, util.NewRecordNotFoundError(err.Error()) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return folder, util.NewRecordNotFoundError(err.Error()) + } + return folder, err } if mappedPath.Valid { folder.MappedPath = mappedPath.String @@ -1688,6 +1694,9 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n return err } } + if newVersion == 0 { + return nil + } return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) }) } diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index f287dbf3..29b82838 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -22,6 +22,14 @@ import ( ) const ( + sqliteResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}"; +DROP TABLE IF EXISTS "{{folders_mapping}}"; +DROP TABLE IF EXISTS "{{admins}}"; +DROP TABLE IF EXISTS "{{folders}}"; +DROP TABLE IF EXISTS "{{shares}}"; +DROP TABLE IF EXISTS "{{users}}"; +DROP TABLE IF EXISTS "{{schema_version}}"; +` sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL); CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, @@ -314,6 +322,9 @@ func (p *SQLiteProvider) initializeDatabase() error { if err == nil && dbVersion.Version > 0 { return ErrNoInitRequired } + if errors.Is(err, sql.ErrNoRows) { + return errSchemaVersionEmpty + } initialSQL := strings.ReplaceAll(sqliteInitialSQL, "{{schema_version}}", sqlTableSchemaVersion) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) @@ -383,6 +394,17 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { } } +func (p *SQLiteProvider) resetDatabase() error { + sql := strings.ReplaceAll(sqliteResetSQL, "{{schema_version}}", sqlTableSchemaVersion) + sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins) + sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders) + sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping) + sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys) + sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) +} + func updateSQLiteDatabaseFromV10(dbHandle *sql.DB) error { if err := updateSQLiteDatabaseFrom10To11(dbHandle); err != nil { return err