Ver Fonte

Add upgrade file

Kailash Nadh há 5 anos atrás
pai
commit
b54c5d8c5e
1 ficheiros alterados com 148 adições e 0 exclusões
  1. 148 0
      upgrade.go

+ 148 - 0
upgrade.go

@@ -0,0 +1,148 @@
+package main
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/jmoiron/sqlx"
+	"github.com/knadh/koanf"
+	"github.com/knadh/listmonk/internal/migrations"
+	"github.com/knadh/stuffbin"
+	"github.com/lib/pq"
+	"golang.org/x/mod/semver"
+)
+
+// migFunc represents a migration function for a particular version.
+// fn (generally) executes database migrations and additionally
+// takes the filesystem and config objects in case there are additional bits
+// of logic to be performed before executing upgrades. fn is idempotent.
+type migFunc struct {
+	version string
+	fn      func(*sqlx.DB, stuffbin.FileSystem, *koanf.Koanf) error
+}
+
+// migList is the list of available migList ordered by the semver.
+// Each migration is a Go file in internal/migrations named after the semver.
+// The functions are named as: v0.7.0 => migrations.V0_7_0() and are idempotent.
+var migList = []migFunc{
+	{"v0.4.0", migrations.V0_4_0},
+	{"v0.7.0", migrations.V0_7_0},
+}
+
+// upgrade upgrades the database to the current version by running SQL migration files
+// for all version from the last known version to the current one.
+func upgrade(db *sqlx.DB, fs stuffbin.FileSystem, prompt bool) {
+	if prompt {
+		var ok string
+		fmt.Printf("** IMPORTANT: Take a backup of the database before upgrading.\n")
+		fmt.Print("continue (y/n)?  ")
+		if _, err := fmt.Scanf("%s", &ok); err != nil {
+			lo.Fatalf("error reading value from terminal: %v", err)
+		}
+		if strings.ToLower(ok) != "y" {
+			fmt.Println("upgrade cancelled")
+			return
+		}
+	}
+
+	_, toRun, err := getPendingMigrations(db)
+	if err != nil {
+		lo.Fatalf("error checking migrations: %v", err)
+	}
+
+	// No migrations to run.
+	if len(toRun) == 0 {
+		lo.Printf("no upgrades to run. Database is up-to-date.")
+		return
+	}
+
+	// Execute migrations in succession.
+	for _, m := range toRun {
+		lo.Printf("running migration %s", m.version)
+		if err := m.fn(db, fs, ko); err != nil {
+			lo.Fatalf("error running migration %s: %v", m.version, err)
+		}
+
+		// Record the migration version in the settings table. There was no
+		// settings table until v0.7.0, so ignore the no-table errors.
+		if err := recordMigrationVersion(m.version, db); err != nil {
+			if isTableNotExistErr(err) {
+				continue
+			}
+			lo.Fatalf("error recording migration version %s: %v", m.version, err)
+		}
+	}
+
+	lo.Printf("upgrade complete")
+}
+
+// checkUpgrade checks if the current database schema matches the expected
+// binary version.
+func checkUpgrade(db *sqlx.DB) {
+	lastVer, toRun, err := getPendingMigrations(db)
+	if err != nil {
+		lo.Fatalf("error checking migrations: %v", err)
+	}
+
+	// No migrations to run.
+	if len(toRun) == 0 {
+		return
+	}
+
+	var vers []string
+	for _, m := range toRun {
+		vers = append(vers, m.version)
+	}
+
+	lo.Fatalf(`there are %d pending database upgrade(s): %v. The last upgrade was %s. Backup the database and run listmonk --upgrade`,
+		len(toRun), vers, lastVer)
+}
+
+// getPendingMigrations gets the pending migrations by comparing the last
+// recorded migration in the DB against all migrations listed in `migrations`.
+func getPendingMigrations(db *sqlx.DB) (string, []migFunc, error) {
+	lastVer, err := getLastMigrationVersion()
+	if err != nil {
+		return "", nil, err
+	}
+
+	// Iterate through the migration versions and get everything above the last
+	// last upgraded semver.
+	var toRun []migFunc
+	for i, m := range migList {
+		if semver.Compare(m.version, lastVer) > 0 {
+			toRun = migList[i:]
+			break
+		}
+	}
+
+	return lastVer, toRun, nil
+}
+
+// getLastMigrationVersion returns the last migration semver recorded in the DB.
+// If there isn't any, `v0.0.0` is returned.
+func getLastMigrationVersion() (string, error) {
+	var v string
+	if err := db.Get(&v, `
+		SELECT COALESCE(
+			(SELECT value->>-1 FROM settings WHERE key='migrations'),
+		'v0.0.0')`); err != nil {
+		if isTableNotExistErr(err) {
+			return "v0.0.0", nil
+		}
+		return v, err
+	}
+	return v, nil
+}
+
+// isPqNoTableErr checks if the given error represents a Postgres/pq
+// "table does not exist" error.
+func isTableNotExistErr(err error) bool {
+	if p, ok := err.(*pq.Error); ok {
+		// `settings` table does not exist. It was introduced in v0.7.0.
+		if p.Code == "42P01" {
+			return true
+		}
+	}
+	return false
+}