diff --git a/example/migrate/main.go b/example/migrate/main.go index 92142e5ee..f763f9ade 100644 --- a/example/migrate/main.go +++ b/example/migrate/main.go @@ -140,6 +140,23 @@ func newDBCommand(migrator *migrate.Migrator) *cli.Command { return nil }, }, + { + Name: "create_tx_sql", + Usage: "create up and down transactional SQL migrations", + Action: func(c *cli.Context) error { + name := strings.Join(c.Args().Slice(), "_") + files, err := migrator.CreateTxSQLMigrations(c.Context, name) + if err != nil { + return err + } + + for _, mf := range files { + fmt.Printf("created transaction migration %s (%s)\n", mf.Name, mf.Path) + } + + return nil + }, + }, { Name: "status", Usage: "print migrations status", diff --git a/migrate/migration.go b/migrate/migration.go index 1a4a67511..3f4076d2b 100644 --- a/migrate/migration.go +++ b/migrate/migration.go @@ -158,6 +158,11 @@ SELECT 1 SELECT 2 ` +const transactionalSQLTemplate = `SET statement_timeout = 0; + +SELECT 1; +` + //------------------------------------------------------------------------------ type MigrationSlice []Migration diff --git a/migrate/migrator.go b/migrate/migrator.go index ddf5485c0..52290b370 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -267,19 +267,39 @@ func (m *Migrator) CreateGoMigration( return mf, nil } -// CreateSQLMigrations creates an up and down SQL migration files. +// CreateTxSQLMigration creates transactional up and down SQL migration files. +func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { + name, err := m.genMigrationName(name) + if err != nil { + return nil, err + } + + up, err := m.createSQL(ctx, name+".up.tx.sql", true) + if err != nil { + return nil, err + } + + down, err := m.createSQL(ctx, name+".down.tx.sql", true) + if err != nil { + return nil, err + } + + return []*MigrationFile{up, down}, nil +} + +// CreateSQLMigrations creates up and down SQL migration files. func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { name, err := m.genMigrationName(name) if err != nil { return nil, err } - up, err := m.createSQL(ctx, name+".up.sql") + up, err := m.createSQL(ctx, name+".up.sql", false) if err != nil { return nil, err } - down, err := m.createSQL(ctx, name+".down.sql") + down, err := m.createSQL(ctx, name+".down.sql", false) if err != nil { return nil, err } @@ -287,10 +307,15 @@ func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*Mig return []*MigrationFile{up, down}, nil } -func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) { +func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) { fpath := filepath.Join(m.migrations.getDirectory(), fname) - if err := os.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil { + template := sqlTemplate + if transactional { + template = transactionalSQLTemplate + } + + if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil { return nil, err }