add resetprovider sub-command

Fixes #608
This commit is contained in:
Nicola Murino 2021-11-15 18:40:31 +01:00
parent ca730e77a5
commit e29a3efd39
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
10 changed files with 199 additions and 9 deletions

View file

@ -135,13 +135,19 @@ sftpgo initprovider --help
You can disable automatic data provider checks/updates at startup by setting the `update_mode` configuration key to `1`. You can disable automatic data provider checks/updates at startup by setting the `update_mode` configuration key to `1`.
You can also reset your provider by using the `resetprovider` sub-command. Take a look at the CLI usage for more details:
```bash
sftpgo resetprovider --help
```
## Create the first admin ## Create the first admin
To start using SFTPGo you need to create an admin user, you can do it in several ways: To start using SFTPGo you need to create an admin user, you can do it in several ways:
- by using the web admin interface. The default URL is [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin) - by using the web admin interface. The default URL is [http://127.0.0.1:8080/web/admin](http://127.0.0.1:8080/web/admin)
- by loading initial data - by loading initial data
- by enabling `create_default_admin` in your configuration file. In this case the credentials are `admin`/`password` - by enabling `create_default_admin` in your configuration file and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD`
## Upgrading ## Upgrading

75
cmd/resetprovider.go Normal file
View file

@ -0,0 +1,75 @@
package cmd
import (
"bufio"
"os"
"strings"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
var (
resetProviderForce bool
resetProviderCmd = &cobra.Command{
Use: "resetprovider",
Short: "Reset the configured provider, any data will be lost",
Long: `This command reads the data provider connection details from the specified
configuration file and resets the provider by deleting all data and schemas.
This command is not supported for the memory provider.
Please take a look at the usage below to customize the options.`,
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = util.CleanDirInput(configDir)
err := config.LoadConfig(configDir, configFile)
if err != nil {
logger.WarnToConsole("Unable to initialize data provider, config load error: %v", err)
os.Exit(1)
}
kmsConfig := config.GetKMSConfig()
err = kmsConfig.Initialize()
if err != nil {
logger.ErrorToConsole("unable to initialize KMS: %v", err)
os.Exit(1)
}
providerConf := config.GetProviderConf()
if !resetProviderForce {
logger.WarnToConsole("You are about to delete all the SFTPGo data for provider %#v, config file: %#v",
providerConf.Driver, viper.ConfigFileUsed())
logger.WarnToConsole("Are you sure? (Y/n)")
reader := bufio.NewReader(os.Stdin)
answer, err := reader.ReadString('\n')
if err != nil {
logger.ErrorToConsole("unable to read your answer: %v", err)
os.Exit(1)
}
if strings.ToUpper(strings.TrimSpace(answer)) != "Y" {
logger.InfoToConsole("command aborted")
os.Exit(1)
}
}
logger.InfoToConsole("Resetting provider: %#v, config file: %#v", providerConf.Driver, viper.ConfigFileUsed())
err = dataprovider.ResetDatabase(providerConf, configDir)
if err != nil {
logger.WarnToConsole("Error resetting provider: %v", err)
os.Exit(1)
}
logger.InfoToConsole("Tha data provider was successfully reset")
},
}
)
func init() {
addConfigFlags(resetProviderCmd)
resetProviderCmd.Flags().BoolVar(&resetProviderForce, "force", false, `reset the provider without asking for confirmation`)
rootCmd.AddCommand(resetProviderCmd)
}

View file

@ -683,8 +683,10 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
for idx, conn := range conns.connections { for idx, conn := range conns.connections {
if conn.GetID() == c.GetID() { if conn.GetID() == c.GetID() {
conn = nil err := conn.CloseFS()
conns.connections[idx] = c conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
conn = nil
return nil return nil
} }
} }

View file

@ -1418,6 +1418,18 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error {
} }
} }
func (p *BoltProvider) resetDatabase() error {
return p.dbHandle.Update(func(tx *bolt.Tx) error {
for _, bucketName := range boltBuckets {
err := tx.DeleteBucket(bucketName)
if err != nil && !errors.Is(err, bolt.ErrBucketNotFound) {
return fmt.Errorf("unable to remove bucket %v: %w", bucketName, err)
}
}
return nil
})
}
func joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) { func joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) {
var user User var user User
err := json.Unmarshal(u, &user) err := json.Unmarshal(u, &user)

View file

@ -457,6 +457,7 @@ type Provider interface {
initializeDatabase() error initializeDatabase() error
migrateDatabase() error migrateDatabase() error
revertDatabase(targetVersion int) error revertDatabase(targetVersion int) error
resetDatabase() error
} }
// SetTempPath sets the path for temporary files // SetTempPath sets the path for temporary files
@ -653,6 +654,22 @@ func RevertDatabase(cnf Config, basePath string, targetVersion int) error {
return provider.revertDatabase(targetVersion) return provider.revertDatabase(targetVersion)
} }
// ResetDatabase restores schema and/or data to a previous version
func ResetDatabase(cnf Config, basePath string) error {
config = cnf
if filepath.IsAbs(config.CredentialsPath) {
credentialsDirPath = config.CredentialsPath
} else {
credentialsDirPath = filepath.Join(basePath, config.CredentialsPath)
}
if err := createProvider(basePath); err != nil {
return err
}
return provider.resetDatabase()
}
// CheckAdminAndPass validates the given admin and password connecting from ip // CheckAdminAndPass validates the given admin and password connecting from ip
func CheckAdminAndPass(username, password, ip string) (Admin, error) { func CheckAdminAndPass(username, password, ip string) (Admin, error) {
return provider.validateAdminAndPass(username, password, ip) return provider.validateAdminAndPass(username, password, ip)

View file

@ -1468,3 +1468,7 @@ func (p *MemoryProvider) migrateDatabase() error {
func (p *MemoryProvider) revertDatabase(targetVersion int) error { func (p *MemoryProvider) revertDatabase(targetVersion int) error {
return errors.New("memory provider does not store data, revert not possible") return errors.New("memory provider does not store data, revert not possible")
} }
func (p *MemoryProvider) resetDatabase() error {
return errors.New("memory provider does not store data, reset not possible")
}

View file

@ -21,6 +21,13 @@ import (
) )
const ( const (
mysqlResetSQL = "DROP TABLE IF EXISTS `{{api_keys}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{admins}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{shares}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
"`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " + "`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " +
@ -318,6 +325,9 @@ func (p *MySQLProvider) initializeDatabase() error {
if err == nil && dbVersion.Version > 0 { if err == nil && dbVersion.Version > 0 {
return ErrNoInitRequired return ErrNoInitRequired
} }
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
initialSQL := strings.ReplaceAll(mysqlInitialSQL, "{{schema_version}}", sqlTableSchemaVersion) initialSQL := strings.ReplaceAll(mysqlInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
@ -387,6 +397,17 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
} }
} }
func (p *MySQLProvider) resetDatabase() error {
sql := strings.ReplaceAll(mysqlResetSQL, "{{schema_version}}", sqlTableSchemaVersion)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0)
}
func updateMySQLDatabaseFromV10(dbHandle *sql.DB) error { func updateMySQLDatabaseFromV10(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom10To11(dbHandle); err != nil { if err := updateMySQLDatabaseFrom10To11(dbHandle); err != nil {
return err return err

View file

@ -21,6 +21,14 @@ import (
) )
const ( const (
pgsqlResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}" CASCADE;
DROP TABLE IF EXISTS "{{folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{admins}}" CASCADE;
DROP TABLE IF EXISTS "{{folders}}" CASCADE;
DROP TABLE IF EXISTS "{{shares}}" CASCADE;
DROP TABLE IF EXISTS "{{users}}" CASCADE;
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
`
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL); pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
CREATE TABLE "{{admins}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, CREATE TABLE "{{admins}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
@ -332,6 +340,9 @@ func (p *PGSQLProvider) initializeDatabase() error {
if err == nil && dbVersion.Version > 0 { if err == nil && dbVersion.Version > 0 {
return ErrNoInitRequired return ErrNoInitRequired
} }
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
initialSQL := strings.ReplaceAll(pgsqlInitial, "{{schema_version}}", sqlTableSchemaVersion) initialSQL := strings.ReplaceAll(pgsqlInitial, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
@ -407,6 +418,17 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
} }
} }
func (p *PGSQLProvider) resetDatabase() error {
sql := strings.ReplaceAll(pgsqlResetSQL, "{{schema_version}}", sqlTableSchemaVersion)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0)
}
func updatePGSQLDatabaseFromV10(dbHandle *sql.DB) error { func updatePGSQLDatabaseFromV10(dbHandle *sql.DB) error {
if err := updatePGSQLDatabaseFrom10To11(dbHandle); err != nil { if err := updatePGSQLDatabaseFrom10To11(dbHandle); err != nil {
return err return err

View file

@ -24,7 +24,10 @@ const (
longSQLQueryTimeout = 60 * time.Second longSQLQueryTimeout = 60 * time.Second
) )
var errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user") var (
errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user")
errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command")
)
type sqlQuerier interface { type sqlQuerier interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
@ -946,7 +949,7 @@ func getShareFromDbRow(row sqlScanner) (Share, error) {
&share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens, &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
&share.UsedTokens, &allowFrom) &share.UsedTokens, &allowFrom)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return share, util.NewRecordNotFoundError(err.Error()) return share, util.NewRecordNotFoundError(err.Error())
} }
return share, err return share, err
@ -986,7 +989,7 @@ func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
&apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID) &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return apiKey, util.NewRecordNotFoundError(err.Error()) return apiKey, util.NewRecordNotFoundError(err.Error())
} }
return apiKey, err return apiKey, err
@ -1013,7 +1016,7 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
&filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin) &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return admin, util.NewRecordNotFoundError(err.Error()) return admin, util.NewRecordNotFoundError(err.Error())
} }
return admin, err return admin, err
@ -1063,7 +1066,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
&additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt) &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return user, util.NewRecordNotFoundError(err.Error()) return user, util.NewRecordNotFoundError(err.Error())
} }
return user, err return user, err
@ -1143,8 +1146,11 @@ func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (
var mappedPath, description, fsConfig sql.NullString var mappedPath, description, fsConfig sql.NullString
err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate, err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
&folder.Name, &description, &fsConfig) &folder.Name, &description, &fsConfig)
if err == sql.ErrNoRows { if err != nil {
return folder, util.NewRecordNotFoundError(err.Error()) if errors.Is(err, sql.ErrNoRows) {
return folder, util.NewRecordNotFoundError(err.Error())
}
return folder, err
} }
if mappedPath.Valid { if mappedPath.Valid {
folder.MappedPath = mappedPath.String folder.MappedPath = mappedPath.String
@ -1688,6 +1694,9 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n
return err return err
} }
} }
if newVersion == 0 {
return nil
}
return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
}) })
} }

View file

@ -22,6 +22,14 @@ import (
) )
const ( const (
sqliteResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}";
DROP TABLE IF EXISTS "{{folders_mapping}}";
DROP TABLE IF EXISTS "{{admins}}";
DROP TABLE IF EXISTS "{{folders}}";
DROP TABLE IF EXISTS "{{shares}}";
DROP TABLE IF EXISTS "{{users}}";
DROP TABLE IF EXISTS "{{schema_version}}";
`
sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL); sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE, CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
@ -314,6 +322,9 @@ func (p *SQLiteProvider) initializeDatabase() error {
if err == nil && dbVersion.Version > 0 { if err == nil && dbVersion.Version > 0 {
return ErrNoInitRequired return ErrNoInitRequired
} }
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
initialSQL := strings.ReplaceAll(sqliteInitialSQL, "{{schema_version}}", sqlTableSchemaVersion) initialSQL := strings.ReplaceAll(sqliteInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins) initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders) initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
@ -383,6 +394,17 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
} }
} }
func (p *SQLiteProvider) resetDatabase() error {
sql := strings.ReplaceAll(sqliteResetSQL, "{{schema_version}}", sqlTableSchemaVersion)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0)
}
func updateSQLiteDatabaseFromV10(dbHandle *sql.DB) error { func updateSQLiteDatabaseFromV10(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom10To11(dbHandle); err != nil { if err := updateSQLiteDatabaseFrom10To11(dbHandle); err != nil {
return err return err