diff --git a/example/migrate/main.go b/example/migrate/main.go index 92142e5ee..8908aba92 100644 --- a/example/migrate/main.go +++ b/example/migrate/main.go @@ -128,7 +128,7 @@ func newDBCommand(migrator *migrate.Migrator) *cli.Command { Usage: "create up and down SQL migrations", Action: func(c *cli.Context) error { name := strings.Join(c.Args().Slice(), "_") - files, err := migrator.CreateSQLMigrations(c.Context, name) + files, err := migrator.CreateSQLMigrations(c.Context, name, false) if err != nil { return err } diff --git a/migrate/migration.go b/migrate/migration.go index 1a4a67511..3f4076d2b 100644 --- a/migrate/migration.go +++ b/migrate/migration.go @@ -158,6 +158,11 @@ SELECT 1 SELECT 2 ` +const transactionalSQLTemplate = `SET statement_timeout = 0; + +SELECT 1; +` + //------------------------------------------------------------------------------ type MigrationSlice []Migration diff --git a/migrate/migrator.go b/migrate/migrator.go index ddf5485c0..6b2b55bf8 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -268,18 +268,26 @@ func (m *Migrator) CreateGoMigration( } // CreateSQLMigrations creates an up and down SQL migration files. -func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { +func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string, transactional bool) ([]*MigrationFile, error) { name, err := m.genMigrationName(name) if err != nil { return nil, err } - up, err := m.createSQL(ctx, name+".up.sql") + upSuffix := ".up.sql" + downSuffix := ".down.sql" + + if transactional { + upSuffix = ".up.tx.sql" + downSuffix = ".down.tx.sql" + } + + up, err := m.createSQL(ctx, name+upSuffix, transactional) if err != nil { return nil, err } - down, err := m.createSQL(ctx, name+".down.sql") + down, err := m.createSQL(ctx, name+downSuffix, transactional) if err != nil { return nil, err } @@ -287,10 +295,15 @@ func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*Mig return []*MigrationFile{up, down}, nil } -func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) { +func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) { fpath := filepath.Join(m.migrations.getDirectory(), fname) - if err := os.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil { + template := sqlTemplate + if transactional { + template = transactionalSQLTemplate + } + + if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil { return nil, err }