initprovider: check if the provider is already initialized

exit with code 0 if no initialization is required
This commit is contained in:
Nicola Murino 2020-08-30 13:50:43 +02:00
parent 2746c0b0f1
commit 600a107699
9 changed files with 22 additions and 15 deletions

View file

@ -46,6 +46,8 @@ Please take a look at the usage below to customize the options.`,
err = dataprovider.InitializeDatabase(providerConf, configDir) err = dataprovider.InitializeDatabase(providerConf, configDir)
if err == nil { if err == nil {
logger.DebugToConsole("Data provider successfully initialized") logger.DebugToConsole("Data provider successfully initialized")
} else if err == dataprovider.ErrNoInitRequired {
logger.DebugToConsole("%v", err.Error())
} else { } else {
logger.WarnToConsole("Unable to initialize data provider: %v", err) logger.WarnToConsole("Unable to initialize data provider: %v", err)
os.Exit(1) os.Exit(1)

View file

@ -703,7 +703,7 @@ func (p BoltProvider) reloadConfig() error {
// initializeDatabase does nothing, no initilization is needed for bolt provider // initializeDatabase does nothing, no initilization is needed for bolt provider
func (p BoltProvider) initializeDatabase() error { func (p BoltProvider) initializeDatabase() error {
return errNoInitRequired return ErrNoInitRequired
} }
func (p BoltProvider) migrateDatabase() error { func (p BoltProvider) migrateDatabase() error {

View file

@ -95,11 +95,13 @@ var (
// ErrNoAuthTryed defines the error for connection closed before authentication // ErrNoAuthTryed defines the error for connection closed before authentication
ErrNoAuthTryed = errors.New("no auth tryed") ErrNoAuthTryed = errors.New("no auth tryed")
// ValidProtocols defines all the valid protcols // ValidProtocols defines all the valid protcols
ValidProtocols = []string{"SSH", "FTP", "DAV"} ValidProtocols = []string{"SSH", "FTP", "DAV"}
config Config // ErrNoInitRequired defines the error returned by InitProvider if no inizialization is required
provider Provider ErrNoInitRequired = errors.New("Data provider initialization is not required")
sqlPlaceholders []string config Config
hashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, provider Provider
sqlPlaceholders []string
hashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix} pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
pbkdfPwdPrefixes = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix} pbkdfPwdPrefixes = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix}
pbkdfPwdB64SaltPrefixes = []string{pbkdf2SHA256B64SaltPrefix} pbkdfPwdB64SaltPrefixes = []string{pbkdf2SHA256B64SaltPrefix}
@ -108,7 +110,6 @@ var (
availabilityTicker *time.Ticker availabilityTicker *time.Ticker
availabilityTickerDone chan bool availabilityTickerDone chan bool
errWrongPassword = errors.New("password does not match") errWrongPassword = errors.New("password does not match")
errNoInitRequired = errors.New("initialization is not required for this data provider")
credentialsDirPath string credentialsDirPath string
sqlTableUsers = "users" sqlTableUsers = "users"
sqlTableFolders = "folders" sqlTableFolders = "folders"
@ -422,7 +423,7 @@ func InitializeDatabase(cnf Config, basePath string) error {
config = cnf config = cnf
if config.Driver == BoltDataProviderName || config.Driver == MemoryDataProviderName { if config.Driver == BoltDataProviderName || config.Driver == MemoryDataProviderName {
return errNoInitRequired return ErrNoInitRequired
} }
err := createProvider(basePath) err := createProvider(basePath)
if err != nil { if err != nil {

View file

@ -667,7 +667,7 @@ func (p MemoryProvider) reloadConfig() error {
// initializeDatabase does nothing, no initilization is needed for memory provider // initializeDatabase does nothing, no initilization is needed for memory provider
func (p MemoryProvider) initializeDatabase() error { func (p MemoryProvider) initializeDatabase() error {
return errNoInitRequired return ErrNoInitRequired
} }
func (p MemoryProvider) migrateDatabase() error { func (p MemoryProvider) migrateDatabase() error {

View file

@ -196,7 +196,7 @@ func (p MySQLProvider) initializeDatabase() error {
} }
func (p MySQLProvider) migrateDatabase() error { func (p MySQLProvider) migrateDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle) dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
if err != nil { if err != nil {
return err return err
} }

View file

@ -195,7 +195,7 @@ func (p PGSQLProvider) initializeDatabase() error {
} }
func (p PGSQLProvider) migrateDatabase() error { func (p PGSQLProvider) migrateDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle) dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
if err != nil { if err != nil {
return err return err
} }

View file

@ -766,7 +766,7 @@ func sqlCommonRollbackTransaction(tx *sql.Tx) {
} }
} }
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) { func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
var result schemaVersion var result schemaVersion
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel() defer cancel()
@ -774,7 +774,7 @@ func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) {
stmt, err := dbHandle.PrepareContext(ctx, q) stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil { if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
if strings.Contains(err.Error(), sqlTableSchemaVersion) { if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?") logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
} }
return result, err return result, err

View file

@ -194,6 +194,10 @@ func (p SQLiteProvider) reloadConfig() error {
// initializeDatabase creates the initial database structure // initializeDatabase creates the initial database structure
func (p SQLiteProvider) initializeDatabase() error { func (p SQLiteProvider) initializeDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false)
if err == nil && dbVersion.Version > 0 {
return ErrNoInitRequired
}
sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1) sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1)
tx, err := p.dbHandle.Begin() tx, err := p.dbHandle.Begin()
if err != nil { if err != nil {
@ -218,7 +222,7 @@ func (p SQLiteProvider) initializeDatabase() error {
} }
func (p SQLiteProvider) migrateDatabase() error { func (p SQLiteProvider) migrateDatabase() error {
dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle) dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
if err != nil { if err != nil {
return err return err
} }

View file

@ -103,4 +103,4 @@
"ca_certificates": [], "ca_certificates": [],
"skip_tls_verify": false "skip_tls_verify": false
} }
} }