diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index ae2b7cc7e..d21e21911 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -15,40 +15,42 @@ type ( Column = sqlschema.BaseColumn ) -func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) sqlschema.Inspector { - return newInspector(db, excludeTables...) +func (d *Dialect) Inspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { + return newInspector(db, options...) } type Inspector struct { - db *bun.DB - excludeTables []string + sqlschema.InspectorConfig + db *bun.DB } var _ sqlschema.Inspector = (*Inspector)(nil) -func newInspector(db *bun.DB, excludeTables ...string) *Inspector { - return &Inspector{db: db, excludeTables: excludeTables} +func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector { + i := &Inspector{db: db} + sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...) + return i } -func (in *Inspector) Inspect(ctx context.Context, schemaName string) (sqlschema.Database, error) { +func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { dbSchema := Schema{ Tables: orderedmap.New[string, sqlschema.Table](), ForeignKeys: make(map[sqlschema.ForeignKey]string), } - exclude := in.excludeTables + exclude := in.ExcludeTables if len(exclude) == 0 { // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. exclude = []string{""} } var tables []*InformationSchemaTable - if err := in.db.NewRaw(sqlInspectTables, schemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { + if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { return dbSchema, err } var fks []*ForeignKey - if err := in.db.NewRaw(sqlInspectForeignKeys, schemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { + if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { return dbSchema, err } dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks)) diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index dd37e2f13..9ecea49da 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -335,7 +335,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db.RegisterModel((*PublisherToJournalist)(nil)) - dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) + dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(tt.schemaName), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable)) if err != nil { t.Skip(err) } @@ -353,7 +353,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { (*Article)(nil), // references Journalist and Publisher ) - got, err := dbInspector.Inspect(ctx, tt.schemaName) + got, err := dbInspector.Inspect(ctx) require.NoError(t, err) // State.FKs store their database names, which differ from dialect to dialect. @@ -523,7 +523,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData( orderedmap.Pair[string, sqlschema.Column]{ @@ -542,7 +542,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, )) - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -562,7 +562,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData( orderedmap.Pair[string, sqlschema.Column]{ @@ -587,7 +587,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, )) - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -606,7 +606,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := &sqlschema.BaseTable{ Name: "models", @@ -616,7 +616,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, } - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -635,10 +635,10 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := sqlschema.NewColumns("id", "email") - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -658,9 +658,9 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("custom_schema")) - got, err := inspector.Inspect(context.Background(), "custom_schema") + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -683,9 +683,9 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*KeepMe)(nil), (*LoseMe)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("want")) - got, err := inspector.Inspect(context.Background(), "want") + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 06bc531af..0705aec30 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -217,19 +217,21 @@ func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigrat // and fail if the inspector cannot successfully retrieve database state. func inspectDbOrSkip(tb testing.TB, db *bun.DB, schemaName ...string) func(context.Context) sqlschema.BaseDatabase { tb.Helper() - // AutoMigrator excludes these tables by default, but here we need to do this explicitly. - inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) - if err != nil { - tb.Skip(err) - } // For convenience, schemaName is an optional parameter in this function. inspectSchema := db.Dialect().DefaultSchema() if len(schemaName) > 0 { inspectSchema = schemaName[0] } + + // AutoMigrator excludes these tables by default, but here we need to do this explicitly. + inspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(inspectSchema), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable)) + if err != nil { + tb.Skip(err) + } + return func(ctx context.Context) sqlschema.BaseDatabase { - state, err := inspector.Inspect(ctx, inspectSchema) + state, err := inspector.Inspect(ctx) require.NoError(tb, err) return state.(sqlschema.BaseDatabase) } diff --git a/migrate/auto.go b/migrate/auto.go index 32582eba3..be6954f01 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -27,6 +27,8 @@ func WithModel(models ...interface{}) AutoMigratorOption { // WithExcludeTable tells the AutoMigrator to ignore a table in the database. // This prevents AutoMigrator from dropping tables which may exist in the schema // but which are not used by the application. +// +// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting. func WithExcludeTable(tables ...string) AutoMigratorOption { return func(m *AutoMigrator) { m.excludeTables = append(m.excludeTables, tables...) @@ -148,7 +150,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err } am.excludeTables = append(am.excludeTables, am.table, am.locksTable) - dbInspector, err := sqlschema.NewInspector(db, am.excludeTables...) + dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...)) if err != nil { return nil, err } @@ -163,7 +165,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err tables := schema.NewTables(db.Dialect()) tables.Register(am.includeModels...) - am.modelInspector = sqlschema.NewBunModelInspector(tables) + am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName)) return am, nil } @@ -171,12 +173,12 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) { var err error - got, err := am.dbInspector.Inspect(ctx, am.schemaName) + got, err := am.dbInspector.Inspect(ctx) if err != nil { return nil, err } - want, err := am.modelInspector.Inspect(ctx, am.schemaName) + want, err := am.modelInspector.Inspect(ctx) if err != nil { return nil, err } diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index ed474ed95..1464e5ccf 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -13,7 +13,13 @@ import ( type InspectorDialect interface { schema.Dialect - Inspector(db *bun.DB, excludeTables ...string) Inspector + + // Inspector returns a new instance of Inspector for the dialect. + // Dialects MAY set their default InspectorConfig values in constructor + // but MUST apply InspectorOptions to ensure they can be overriden. + // + // Use ApplyInspectorOptions to reduce boilerplate. + Inspector(db *bun.DB, options ...InspectorOption) Inspector // EquivalentType returns true if col1 and co2 SQL types are equivalent, // i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT) @@ -21,61 +27,77 @@ type InspectorDialect interface { EquivalentType(Column, Column) bool } +// InspectorConfig controls the scope of migration by limiting the objects Inspector should return. +// Inspectors SHOULD use the configuration directly instead of copying it, or MAY choose to embed it, +// to make sure options are always applied correctly. +type InspectorConfig struct { + // SchemaName limits inspection to tables in a particular schema. + SchemaName string + + // ExcludeTables from inspection. + ExcludeTables []string +} + // Inspector reads schema state. type Inspector interface { - Inspect(ctx context.Context, schemaName string) (Database, error) + Inspect(ctx context.Context) (Database, error) } -// inspector is opaque pointer to a databse inspector. -type inspector struct { - Inspector +func WithSchemaName(schemaName string) InspectorOption { + return func(cfg *InspectorConfig) { + cfg.SchemaName = schemaName + } +} + +// WithExcludeTables works in append-only mode, i.e. tables cannot be re-included. +func WithExcludeTables(tables ...string) InspectorOption { + return func(cfg *InspectorConfig) { + cfg.ExcludeTables = append(cfg.ExcludeTables, tables...) + } } // NewInspector creates a new database inspector, if the dialect supports it. -func NewInspector(db *bun.DB, excludeTables ...string) (Inspector, error) { +func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) { dialect, ok := (db.Dialect()).(InspectorDialect) if !ok { return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name()) } return &inspector{ - Inspector: dialect.Inspector(db, excludeTables...), + Inspector: dialect.Inspector(db, options...), }, nil } -// BunModelInspector creates the current project state from the passed bun.Models. -// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run. -type BunModelInspector struct { - tables *schema.Tables -} - -var _ Inspector = (*BunModelInspector)(nil) - -func NewBunModelInspector(tables *schema.Tables) *BunModelInspector { - return &BunModelInspector{ +func NewBunModelInspector(tables *schema.Tables, options ...InspectorOption) *BunModelInspector { + bmi := &BunModelInspector{ tables: tables, } + ApplyInspectorOptions(&bmi.InspectorConfig, options...) + return bmi } -// BunModelSchema is the schema state derived from bun table models. -type BunModelSchema struct { - BaseDatabase +type InspectorOption func(*InspectorConfig) - Tables *orderedmap.OrderedMap[string, Table] +func ApplyInspectorOptions(cfg *InspectorConfig, options ...InspectorOption) { + for _, opt := range options { + opt(cfg) + } } -func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] { - return ms.Tables +// inspector is opaque pointer to a database inspector. +type inspector struct { + Inspector } -// BunTable provides additional table metadata that is only accessible from scanning bun models. -type BunTable struct { - BaseTable - - // Model stores the zero interface to the underlying Go struct. - Model interface{} +// BunModelInspector creates the current project state from the passed bun.Models. +// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run. +type BunModelInspector struct { + InspectorConfig + tables *schema.Tables } -func (bmi *BunModelInspector) Inspect(ctx context.Context, schemaName string) (Database, error) { +var _ Inspector = (*BunModelInspector)(nil) + +func (bmi *BunModelInspector) Inspect(ctx context.Context) (Database, error) { state := BunModelSchema{ BaseDatabase: BaseDatabase{ ForeignKeys: make(map[ForeignKey]string), @@ -83,7 +105,7 @@ func (bmi *BunModelInspector) Inspect(ctx context.Context, schemaName string) (D Tables: orderedmap.New[string, Table](), } for _, t := range bmi.tables.All() { - if t.Schema != schemaName { + if t.Schema != bmi.SchemaName { continue } @@ -198,3 +220,22 @@ func exprToLower(s string) string { } return strings.ToLower(s) } + +// BunModelSchema is the schema state derived from bun table models. +type BunModelSchema struct { + BaseDatabase + + Tables *orderedmap.OrderedMap[string, Table] +} + +func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] { + return ms.Tables +} + +// BunTable provides additional table metadata that is only accessible from scanning bun models. +type BunTable struct { + BaseTable + + // Model stores the zero interface to the underlying Go struct. + Model interface{} +}