Преглед изворни кода

initprovider: check if the provider is already initialized

exit with code 0 if no initialization is required
Nicola Murino пре 4 година
родитељ
комит
600a107699

+ 2 - 0
cmd/initprovider.go

@@ -46,6 +46,8 @@ Please take a look at the usage below to customize the options.`,
 			err = dataprovider.InitializeDatabase(providerConf, configDir)
 			if err == nil {
 				logger.DebugToConsole("Data provider successfully initialized")
+			} else if err == dataprovider.ErrNoInitRequired {
+				logger.DebugToConsole("%v", err.Error())
 			} else {
 				logger.WarnToConsole("Unable to initialize data provider: %v", err)
 				os.Exit(1)

+ 1 - 1
dataprovider/bolt.go

@@ -703,7 +703,7 @@ func (p BoltProvider) reloadConfig() error {
 
 // initializeDatabase does nothing, no initilization is needed for bolt provider
 func (p BoltProvider) initializeDatabase() error {
-	return errNoInitRequired
+	return ErrNoInitRequired
 }
 
 func (p BoltProvider) migrateDatabase() error {

+ 8 - 7
dataprovider/dataprovider.go

@@ -95,11 +95,13 @@ var (
 	// ErrNoAuthTryed defines the error for connection closed before authentication
 	ErrNoAuthTryed = errors.New("no auth tryed")
 	// ValidProtocols defines all the valid protcols
-	ValidProtocols  = []string{"SSH", "FTP", "DAV"}
-	config          Config
-	provider        Provider
-	sqlPlaceholders []string
-	hashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
+	ValidProtocols = []string{"SSH", "FTP", "DAV"}
+	// ErrNoInitRequired defines the error returned by InitProvider if no inizialization is required
+	ErrNoInitRequired = errors.New("Data provider initialization is not required")
+	config            Config
+	provider          Provider
+	sqlPlaceholders   []string
+	hashPwdPrefixes   = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
 		pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
 	pbkdfPwdPrefixes        = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix}
 	pbkdfPwdB64SaltPrefixes = []string{pbkdf2SHA256B64SaltPrefix}
@@ -108,7 +110,6 @@ var (
 	availabilityTicker      *time.Ticker
 	availabilityTickerDone  chan bool
 	errWrongPassword        = errors.New("password does not match")
-	errNoInitRequired       = errors.New("initialization is not required for this data provider")
 	credentialsDirPath      string
 	sqlTableUsers           = "users"
 	sqlTableFolders         = "folders"
@@ -422,7 +423,7 @@ func InitializeDatabase(cnf Config, basePath string) error {
 	config = cnf
 
 	if config.Driver == BoltDataProviderName || config.Driver == MemoryDataProviderName {
-		return errNoInitRequired
+		return ErrNoInitRequired
 	}
 	err := createProvider(basePath)
 	if err != nil {

+ 1 - 1
dataprovider/memory.go

@@ -667,7 +667,7 @@ func (p MemoryProvider) reloadConfig() error {
 
 // initializeDatabase does nothing, no initilization is needed for memory provider
 func (p MemoryProvider) initializeDatabase() error {
-	return errNoInitRequired
+	return ErrNoInitRequired
 }
 
 func (p MemoryProvider) migrateDatabase() error {

+ 1 - 1
dataprovider/mysql.go

@@ -196,7 +196,7 @@ func (p MySQLProvider) initializeDatabase() error {
 }
 
 func (p MySQLProvider) migrateDatabase() error {
-	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle)
+	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
 	if err != nil {
 		return err
 	}

+ 1 - 1
dataprovider/pgsql.go

@@ -195,7 +195,7 @@ func (p PGSQLProvider) initializeDatabase() error {
 }
 
 func (p PGSQLProvider) migrateDatabase() error {
-	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle)
+	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
 	if err != nil {
 		return err
 	}

+ 2 - 2
dataprovider/sqlcommon.go

@@ -766,7 +766,7 @@ func sqlCommonRollbackTransaction(tx *sql.Tx) {
 	}
 }
 
-func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) {
+func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
 	var result schemaVersion
 	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
 	defer cancel()
@@ -774,7 +774,7 @@ func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) {
 	stmt, err := dbHandle.PrepareContext(ctx, q)
 	if err != nil {
 		providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
-		if strings.Contains(err.Error(), sqlTableSchemaVersion) {
+		if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
 			logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
 		}
 		return result, err

+ 5 - 1
dataprovider/sqlite.go

@@ -194,6 +194,10 @@ func (p SQLiteProvider) reloadConfig() error {
 
 // initializeDatabase creates the initial database structure
 func (p SQLiteProvider) initializeDatabase() error {
+	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false)
+	if err == nil && dbVersion.Version > 0 {
+		return ErrNoInitRequired
+	}
 	sqlUsers := strings.Replace(sqliteUsersTableSQL, "{{users}}", sqlTableUsers, 1)
 	tx, err := p.dbHandle.Begin()
 	if err != nil {
@@ -218,7 +222,7 @@ func (p SQLiteProvider) initializeDatabase() error {
 }
 
 func (p SQLiteProvider) migrateDatabase() error {
-	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle)
+	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
 	if err != nil {
 		return err
 	}

+ 1 - 1
sftpgo.json

@@ -103,4 +103,4 @@
     "ca_certificates": [],
     "skip_tls_verify": false
   }
-}
+}