diff --git a/README.md b/README.md index 0b392b9f..310f372d 100644 --- a/README.md +++ b/README.md @@ -156,7 +156,7 @@ You can also reset your provider by using the `resetprovider` sub-command. Take sftpgo resetprovider --help ``` -:warning: Please note that some data providers (e.g. MySQL and CockroachDB) do not support schema changes within a transaction, this means that you may end up with an inconsistent schema if migrations are forcibly aborted or if they are run concurrently by multiple instances. +:warning: Please note that some data providers (e.g. MySQL and CockroachDB) do not support schema changes within a transaction, this means that you may end up with an inconsistent schema if migrations are forcibly aborted. CockroachDB doesn't support database-level locks, so make sure you don't execute migrations concurrently. ## Create the first admin diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index c6c69744..f9abcadc 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -234,7 +234,7 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) { return "", fmt.Errorf("unable to register tls config: %v", err) } } - connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=10s&readTimeout=10s", + connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s", config.Username, password, config.Host, config.Port, config.Name, sslMode) } else { connectionString = config.ConnectionString diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 94b26863..9c2b270a 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -2911,10 +2911,11 @@ func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, return userID, adminID, nil } -func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) { +func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) { var result schemaVersion ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() + q := getDatabaseVersionQuery() stmt, err := dbHandle.PrepareContext(ctx, q) if err != nil { @@ -2943,9 +2944,23 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve } func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error { + if err := sqlAquireLock(dbHandle); err != nil { + return err + } + defer sqlReleaseLock(dbHandle) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() + if newVersion > 0 { + currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false) + if err == nil && currentVersion.Version >= newVersion { + providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?", + currentVersion.Version, newVersion) + return nil + } + } + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { for _, q := range sqlQueries { if strings.TrimSpace(q) == "" { @@ -2963,6 +2978,63 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n }) } +func sqlAquireLock(dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + switch config.Driver { + case PGSQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`) + if err != nil { + return fmt.Errorf("unable to get advisory lock: %w", err) + } + providerLog(logger.LevelInfo, "acquired database lock") + case MySQLDataProviderName: + stmt, err := dbHandle.PrepareContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`) + if err != nil { + return fmt.Errorf("unable to get lock: %w", err) + } + defer stmt.Close() + + var lockResult sql.NullInt64 + err = stmt.QueryRowContext(ctx).Scan(&lockResult) + if err != nil { + return fmt.Errorf("unable to get lock: %w", err) + } + if !lockResult.Valid { + return errors.New("unable to get lock: null value returned") + } + if lockResult.Int64 != 1 { + return fmt.Errorf("unable to get lock, result: %v", lockResult.Int64) + } + providerLog(logger.LevelInfo, "acquired database lock") + } + + return nil +} + +func sqlReleaseLock(dbHandle *sql.DB) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + switch config.Driver { + case PGSQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`) + if err != nil { + providerLog(logger.LevelWarn, "unable to release lock: %v", err) + } else { + providerLog(logger.LevelInfo, "released database lock") + } + case MySQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`) + if err != nil { + providerLog(logger.LevelWarn, "unable to release lock: %v", err) + } else { + providerLog(logger.LevelInfo, "released database lock") + } + } +} + func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error { if config.Driver == CockroachDataProviderName { return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)