Skip to content

Commit

Permalink
Merge pull request #1058 from bevzzz/feature/automigrate-followup
Browse files Browse the repository at this point in the history
Follow-up #926
  • Loading branch information
vmihailenco authored Nov 13, 2024
2 parents b10bd39 + c228b0e commit abcd779
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 91 deletions.
4 changes: 2 additions & 2 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/uptrace/bun/schema"
)

func (d *Dialect) Migrator(db *bun.DB, schemaName string) sqlschema.Migrator {
func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator {
return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)}
}

Expand Down Expand Up @@ -202,7 +202,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi
got, want := colDef.From, colDef.To

inspector := m.db.Dialect().(sqlschema.InspectorDialect)
if !inspector.EquivalentType(want, got) {
if !inspector.CompareType(want, got) {
appendAlterColumn()
b = append(b, " SET DATA TYPE "...)
if b, err = want.AppendQuery(fmter, b); err != nil {
Expand Down
22 changes: 12 additions & 10 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,42 @@ type (
Column = sqlschema.BaseColumn
)

func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) sqlschema.Inspector {
return newInspector(db, excludeTables...)
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
return newInspector(db, options...)
}

type Inspector struct {
db *bun.DB
excludeTables []string
sqlschema.InspectorConfig
db *bun.DB
}

var _ sqlschema.Inspector = (*Inspector)(nil)

func newInspector(db *bun.DB, excludeTables ...string) *Inspector {
return &Inspector{db: db, excludeTables: excludeTables}
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
i := &Inspector{db: db}
sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
return i
}

func (in *Inspector) Inspect(ctx context.Context, schemaName string) (sqlschema.Database, error) {
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
dbSchema := Schema{
Tables: orderedmap.New[string, sqlschema.Table](),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}

exclude := in.excludeTables
exclude := in.ExcludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}

var tables []*InformationSchemaTable
if err := in.db.NewRaw(sqlInspectTables, schemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return dbSchema, err
}

var fks []*ForeignKey
if err := in.db.NewRaw(sqlInspectForeignKeys, schemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
return dbSchema, err
}
dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))
Expand Down
2 changes: 1 addition & 1 deletion dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ var (
timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz)
)

func (d *Dialect) EquivalentType(col1, col2 sqlschema.Column) bool {
func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool {
typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType())

if typ1 == typ2 {
Expand Down
6 changes: 3 additions & 3 deletions dialect/pgdialect/sqltype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/uptrace/bun/migrate/sqlschema"
)

func TestInspectorDialect_EquivalentType(t *testing.T) {
func TestInspectorDialect_CompareType(t *testing.T) {
d := New()

t.Run("common types", func(t *testing.T) {
Expand Down Expand Up @@ -41,7 +41,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) {
eq = " !~ "
}
t.Run(tt.typ1+eq+tt.typ2, func(t *testing.T) {
got := d.EquivalentType(
got := d.CompareType(
&sqlschema.BaseColumn{SQLType: tt.typ1},
&sqlschema.BaseColumn{SQLType: tt.typ2},
)
Expand Down Expand Up @@ -77,7 +77,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) {
},
} {
t.Run(tt.name, func(t *testing.T) {
got := d.EquivalentType(&tt.col1, &tt.col2)
got := d.CompareType(&tt.col1, &tt.col2)
require.Equal(t, tt.want, got)
})
}
Expand Down
30 changes: 15 additions & 15 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.RegisterModel((*PublisherToJournalist)(nil))

dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(tt.schemaName), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
t.Skip(err)
}
Expand All @@ -353,7 +353,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
(*Article)(nil), // references Journalist and Publisher
)

got, err := dbInspector.Inspect(ctx, tt.schemaName)
got, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

// State.FKs store their database names, which differ from dialect to dialect.
Expand Down Expand Up @@ -433,7 +433,7 @@ func cmpColumns(
continue
}

if !d.EquivalentType(wantCol, gotCol) {
if !d.CompareType(wantCol, gotCol) {
errorf("sql types are not equivalent:\n\t(+want)\t%s\n\t(-got)\t%s", formatType(wantCol), formatType(gotCol))
}

Expand Down Expand Up @@ -523,7 +523,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -542,7 +542,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -562,7 +562,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -587,7 +587,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -606,7 +606,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := &sqlschema.BaseTable{
Name: "models",
Expand All @@ -616,7 +616,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
}

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -635,10 +635,10 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))
want := sqlschema.NewColumns("id", "email")

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -658,9 +658,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("custom_schema"))

got, err := inspector.Inspect(context.Background(), "custom_schema")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -683,9 +683,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*KeepMe)(nil), (*LoseMe)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("want"))

got, err := inspector.Inspect(context.Background(), "want")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand Down
14 changes: 8 additions & 6 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,21 @@ func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigrat
// and fail if the inspector cannot successfully retrieve database state.
func inspectDbOrSkip(tb testing.TB, db *bun.DB, schemaName ...string) func(context.Context) sqlschema.BaseDatabase {
tb.Helper()
// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
if err != nil {
tb.Skip(err)
}

// For convenience, schemaName is an optional parameter in this function.
inspectSchema := db.Dialect().DefaultSchema()
if len(schemaName) > 0 {
inspectSchema = schemaName[0]
}

// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(inspectSchema), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
tb.Skip(err)
}

return func(ctx context.Context) sqlschema.BaseDatabase {
state, err := inspector.Inspect(ctx, inspectSchema)
state, err := inspector.Inspect(ctx)
require.NoError(tb, err)
return state.(sqlschema.BaseDatabase)
}
Expand Down
12 changes: 7 additions & 5 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func WithModel(models ...interface{}) AutoMigratorOption {
// WithExcludeTable tells the AutoMigrator to ignore a table in the database.
// This prevents AutoMigrator from dropping tables which may exist in the schema
// but which are not used by the application.
//
// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting.
func WithExcludeTable(tables ...string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.excludeTables = append(m.excludeTables, tables...)
Expand Down Expand Up @@ -148,12 +150,12 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err
}
am.excludeTables = append(am.excludeTables, am.table, am.locksTable)

dbInspector, err := sqlschema.NewInspector(db, am.excludeTables...)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...))
if err != nil {
return nil, err
}
am.dbInspector = dbInspector
am.diffOpts = append(am.diffOpts, withTypeEquivalenceFunc(db.Dialect().(sqlschema.InspectorDialect).EquivalentType))
am.diffOpts = append(am.diffOpts, withCompareTypeFunc(db.Dialect().(sqlschema.InspectorDialect).CompareType))

dbMigrator, err := sqlschema.NewMigrator(db, am.schemaName)
if err != nil {
Expand All @@ -163,20 +165,20 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err

tables := schema.NewTables(db.Dialect())
tables.Register(am.includeModels...)
am.modelInspector = sqlschema.NewBunModelInspector(tables)
am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName))

return am, nil
}

func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {
var err error

got, err := am.dbInspector.Inspect(ctx, am.schemaName)
got, err := am.dbInspector.Inspect(ctx)
if err != nil {
return nil, err
}

want, err := am.modelInspector.Inspect(ctx, am.schemaName)
want, err := am.modelInspector.Inspect(ctx)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit abcd779

Please sign in to comment.