212 lines
5.4 KiB
Go
212 lines
5.4 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/influxdata/influxdb/v2/kit/migration"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type Migrator struct {
|
|
store *SqlStore
|
|
log *zap.Logger
|
|
|
|
backupPath string
|
|
}
|
|
|
|
func NewMigrator(store *SqlStore, log *zap.Logger) *Migrator {
|
|
return &Migrator{
|
|
store: store,
|
|
log: log,
|
|
}
|
|
}
|
|
|
|
// SetBackupPath records the filepath where pre-migration state should be written prior to running migrations.
|
|
func (m *Migrator) SetBackupPath(path string) {
|
|
m.backupPath = path
|
|
}
|
|
|
|
func (m *Migrator) Up(ctx context.Context, source embed.FS) error {
|
|
return m.UpUntil(ctx, -1, source)
|
|
}
|
|
|
|
// UpUntil migrates until a specific migration.
|
|
// -1 or 0 will run all migrations, any other number will run up until that.
|
|
// Returns no error untilMigration is less than the already run migrations.
|
|
func (m *Migrator) UpUntil(ctx context.Context, untilMigration int, source embed.FS) error {
|
|
knownMigrations, err := source.ReadDir(".")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// sort the list according to the version number to ensure the migrations are applied in the correct order
|
|
sort.Slice(knownMigrations, func(i, j int) bool {
|
|
return knownMigrations[i].Name() < knownMigrations[j].Name()
|
|
})
|
|
|
|
executedMigrations, err := m.store.allMigrationNames()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var lastMigration int
|
|
for idx := range executedMigrations {
|
|
if idx > len(knownMigrations)-1 || executedMigrations[idx] != dropExtension(knownMigrations[idx].Name()) {
|
|
return migration.ErrInvalidMigration(executedMigrations[idx])
|
|
}
|
|
|
|
lastMigration, err = scriptVersion(executedMigrations[idx])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
var migrationsToDo int
|
|
if untilMigration < 1 {
|
|
migrationsToDo = len(knownMigrations[lastMigration:])
|
|
untilMigration = len(knownMigrations)
|
|
} else if untilMigration >= lastMigration {
|
|
migrationsToDo = len(knownMigrations[lastMigration:untilMigration])
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
if migrationsToDo == 0 {
|
|
return nil
|
|
}
|
|
|
|
if m.backupPath != "" && lastMigration != 0 {
|
|
m.log.Info("Backing up pre-migration metadata", zap.String("backup_path", m.backupPath))
|
|
if err := func() error {
|
|
out, err := os.Create(m.backupPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer out.Close()
|
|
|
|
if err := m.store.BackupSqlStore(ctx, out); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}(); err != nil {
|
|
return fmt.Errorf("failed to back up pre-migration metadata: %w", err)
|
|
}
|
|
}
|
|
|
|
m.log.Info("Bringing up metadata migrations", zap.Int("migration_count", migrationsToDo))
|
|
|
|
for _, f := range knownMigrations[lastMigration:untilMigration] {
|
|
n := f.Name()
|
|
|
|
m.log.Debug("Executing metadata migration", zap.String("migration_name", n))
|
|
mBytes, err := source.ReadFile(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
recordStmt := fmt.Sprintf(`INSERT INTO %s (name) VALUES (%q);`, migrationsTableName, dropExtension(n))
|
|
|
|
if err := m.store.execTrans(ctx, string(mBytes)+recordStmt); err != nil {
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Down applies the "down" migrations until the SQL database has migrations only >= untilMigration. Use untilMigration = 0 to apply all
|
|
// down migrations, which will delete all data from the database.
|
|
func (m *Migrator) Down(ctx context.Context, untilMigration int, source embed.FS) error {
|
|
knownMigrations, err := source.ReadDir(".")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// sort the list according to the version number to ensure the migrations are applied in the correct order
|
|
sort.Slice(knownMigrations, func(i, j int) bool {
|
|
return knownMigrations[i].Name() < knownMigrations[j].Name()
|
|
})
|
|
|
|
executedMigrations, err := m.store.allMigrationNames()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
migrationsToDo := len(executedMigrations) - untilMigration
|
|
if migrationsToDo == 0 {
|
|
return nil
|
|
}
|
|
|
|
if migrationsToDo < 0 {
|
|
m.log.Warn("SQL metadata is already on a schema older than target, nothing to do")
|
|
return nil
|
|
}
|
|
|
|
if m.backupPath != "" {
|
|
m.log.Info("Backing up pre-migration metadata", zap.String("backup_path", m.backupPath))
|
|
if err := func() error {
|
|
out, err := os.Create(m.backupPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer out.Close()
|
|
|
|
if err := m.store.BackupSqlStore(ctx, out); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}(); err != nil {
|
|
return fmt.Errorf("failed to back up pre-migration metadata: %w", err)
|
|
}
|
|
}
|
|
|
|
m.log.Info("Tearing down metadata migrations", zap.Int("migration_count", migrationsToDo))
|
|
|
|
for i := len(executedMigrations) - 1; i >= untilMigration; i-- {
|
|
downName := knownMigrations[i].Name()
|
|
downNameNoExtension := dropExtension(downName)
|
|
|
|
m.log.Debug("Executing metadata migration", zap.String("migration_name", downName))
|
|
mBytes, err := source.ReadFile(downName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
deleteStmt := fmt.Sprintf(`DELETE FROM %s WHERE name = %q;`, migrationsTableName, downNameNoExtension)
|
|
|
|
if err := m.store.execTrans(ctx, deleteStmt+string(mBytes)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// extract the version number as an integer from a file named like "0002_migration_name.sql"
|
|
func scriptVersion(filename string) (int, error) {
|
|
vString := strings.Split(filename, "_")[0]
|
|
vInt, err := strconv.Atoi(vString)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return vInt, nil
|
|
}
|
|
|
|
// dropExtension returns the filename excluding anything after the first "."
|
|
func dropExtension(filename string) string {
|
|
idx := strings.Index(filename, ".")
|
|
if idx == -1 {
|
|
return filename
|
|
}
|
|
|
|
return filename[:idx]
|
|
}
|