Skip to content

Commit

Permalink
refactor: pass schemaName and excludeTables as InspectorOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
bevzzz committed Nov 13, 2024
1 parent ac8d221 commit b2288fc
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 65 deletions.
22 changes: 12 additions & 10 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,42 @@ type (
Column = sqlschema.BaseColumn
)

func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) sqlschema.Inspector {
return newInspector(db, excludeTables...)
func (d *Dialect) Inspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
return newInspector(db, options...)
}

type Inspector struct {
db *bun.DB
excludeTables []string
sqlschema.InspectorConfig
db *bun.DB
}

var _ sqlschema.Inspector = (*Inspector)(nil)

func newInspector(db *bun.DB, excludeTables ...string) *Inspector {
return &Inspector{db: db, excludeTables: excludeTables}
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
i := &Inspector{db: db}
sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
return i
}

func (in *Inspector) Inspect(ctx context.Context, schemaName string) (sqlschema.Database, error) {
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
dbSchema := Schema{
Tables: orderedmap.New[string, sqlschema.Table](),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}

exclude := in.excludeTables
exclude := in.ExcludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}

var tables []*InformationSchemaTable
if err := in.db.NewRaw(sqlInspectTables, schemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return dbSchema, err
}

var fks []*ForeignKey
if err := in.db.NewRaw(sqlInspectForeignKeys, schemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
return dbSchema, err
}
dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))
Expand Down
28 changes: 14 additions & 14 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.RegisterModel((*PublisherToJournalist)(nil))

dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(tt.schemaName), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
t.Skip(err)
}
Expand All @@ -353,7 +353,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
(*Article)(nil), // references Journalist and Publisher
)

got, err := dbInspector.Inspect(ctx, tt.schemaName)
got, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

// State.FKs store their database names, which differ from dialect to dialect.
Expand Down Expand Up @@ -523,7 +523,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -542,7 +542,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -562,7 +562,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -587,7 +587,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -606,7 +606,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := &sqlschema.BaseTable{
Name: "models",
Expand All @@ -616,7 +616,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
}

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -635,10 +635,10 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))
want := sqlschema.NewColumns("id", "email")

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -658,9 +658,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("custom_schema"))

got, err := inspector.Inspect(context.Background(), "custom_schema")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -683,9 +683,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*KeepMe)(nil), (*LoseMe)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("want"))

got, err := inspector.Inspect(context.Background(), "want")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand Down
14 changes: 8 additions & 6 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,21 @@ func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigrat
// and fail if the inspector cannot successfully retrieve database state.
func inspectDbOrSkip(tb testing.TB, db *bun.DB, schemaName ...string) func(context.Context) sqlschema.BaseDatabase {
tb.Helper()
// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
if err != nil {
tb.Skip(err)
}

// For convenience, schemaName is an optional parameter in this function.
inspectSchema := db.Dialect().DefaultSchema()
if len(schemaName) > 0 {
inspectSchema = schemaName[0]
}

// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(inspectSchema), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
tb.Skip(err)
}

return func(ctx context.Context) sqlschema.BaseDatabase {
state, err := inspector.Inspect(ctx, inspectSchema)
state, err := inspector.Inspect(ctx)
require.NoError(tb, err)
return state.(sqlschema.BaseDatabase)
}
Expand Down
10 changes: 6 additions & 4 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func WithModel(models ...interface{}) AutoMigratorOption {
// WithExcludeTable tells the AutoMigrator to ignore a table in the database.
// This prevents AutoMigrator from dropping tables which may exist in the schema
// but which are not used by the application.
//
// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting.
func WithExcludeTable(tables ...string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.excludeTables = append(m.excludeTables, tables...)
Expand Down Expand Up @@ -148,7 +150,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err
}
am.excludeTables = append(am.excludeTables, am.table, am.locksTable)

dbInspector, err := sqlschema.NewInspector(db, am.excludeTables...)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...))
if err != nil {
return nil, err
}
Expand All @@ -163,20 +165,20 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err

tables := schema.NewTables(db.Dialect())
tables.Register(am.includeModels...)
am.modelInspector = sqlschema.NewBunModelInspector(tables)
am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName))

return am, nil
}

func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {
var err error

got, err := am.dbInspector.Inspect(ctx, am.schemaName)
got, err := am.dbInspector.Inspect(ctx)
if err != nil {
return nil, err
}

want, err := am.modelInspector.Inspect(ctx, am.schemaName)
want, err := am.modelInspector.Inspect(ctx)
if err != nil {
return nil, err
}
Expand Down
103 changes: 72 additions & 31 deletions migrate/sqlschema/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,77 +13,99 @@ import (

type InspectorDialect interface {
schema.Dialect
Inspector(db *bun.DB, excludeTables ...string) Inspector

// Inspector returns a new instance of Inspector for the dialect.
// Dialects MAY set their default InspectorConfig values in constructor
// but MUST apply InspectorOptions to ensure they can be overriden.
//
// Use ApplyInspectorOptions to reduce boilerplate.
Inspector(db *bun.DB, options ...InspectorOption) Inspector

// EquivalentType returns true if col1 and co2 SQL types are equivalent,
// i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT)
// or specify the same VARCHAR length differently (VARCHAR(255) ~ VARCHAR).
EquivalentType(Column, Column) bool
}

// InspectorConfig controls the scope of migration by limiting the objects Inspector should return.
// Inspectors SHOULD use the configuration directly instead of copying it, or MAY choose to embed it,
// to make sure options are always applied correctly.
type InspectorConfig struct {
// SchemaName limits inspection to tables in a particular schema.
SchemaName string

// ExcludeTables from inspection.
ExcludeTables []string
}

// Inspector reads schema state.
type Inspector interface {
Inspect(ctx context.Context, schemaName string) (Database, error)
Inspect(ctx context.Context) (Database, error)
}

// inspector is opaque pointer to a databse inspector.
type inspector struct {
Inspector
func WithSchemaName(schemaName string) InspectorOption {
return func(cfg *InspectorConfig) {
cfg.SchemaName = schemaName
}
}

// WithExcludeTables works in append-only mode, i.e. tables cannot be re-included.
func WithExcludeTables(tables ...string) InspectorOption {
return func(cfg *InspectorConfig) {
cfg.ExcludeTables = append(cfg.ExcludeTables, tables...)
}
}

// NewInspector creates a new database inspector, if the dialect supports it.
func NewInspector(db *bun.DB, excludeTables ...string) (Inspector, error) {
func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) {
dialect, ok := (db.Dialect()).(InspectorDialect)
if !ok {
return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name())
}
return &inspector{
Inspector: dialect.Inspector(db, excludeTables...),
Inspector: dialect.Inspector(db, options...),
}, nil
}

// BunModelInspector creates the current project state from the passed bun.Models.
// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run.
type BunModelInspector struct {
tables *schema.Tables
}

var _ Inspector = (*BunModelInspector)(nil)

func NewBunModelInspector(tables *schema.Tables) *BunModelInspector {
return &BunModelInspector{
func NewBunModelInspector(tables *schema.Tables, options ...InspectorOption) *BunModelInspector {
bmi := &BunModelInspector{
tables: tables,
}
ApplyInspectorOptions(&bmi.InspectorConfig, options...)
return bmi
}

// BunModelSchema is the schema state derived from bun table models.
type BunModelSchema struct {
BaseDatabase
type InspectorOption func(*InspectorConfig)

Tables *orderedmap.OrderedMap[string, Table]
func ApplyInspectorOptions(cfg *InspectorConfig, options ...InspectorOption) {
for _, opt := range options {
opt(cfg)
}
}

func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] {
return ms.Tables
// inspector is opaque pointer to a database inspector.
type inspector struct {
Inspector
}

// BunTable provides additional table metadata that is only accessible from scanning bun models.
type BunTable struct {
BaseTable

// Model stores the zero interface to the underlying Go struct.
Model interface{}
// BunModelInspector creates the current project state from the passed bun.Models.
// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run.
type BunModelInspector struct {
InspectorConfig
tables *schema.Tables
}

func (bmi *BunModelInspector) Inspect(ctx context.Context, schemaName string) (Database, error) {
var _ Inspector = (*BunModelInspector)(nil)

func (bmi *BunModelInspector) Inspect(ctx context.Context) (Database, error) {
state := BunModelSchema{
BaseDatabase: BaseDatabase{
ForeignKeys: make(map[ForeignKey]string),
},
Tables: orderedmap.New[string, Table](),
}
for _, t := range bmi.tables.All() {
if t.Schema != schemaName {
if t.Schema != bmi.SchemaName {
continue
}

Expand Down Expand Up @@ -198,3 +220,22 @@ func exprToLower(s string) string {
}
return strings.ToLower(s)
}

// BunModelSchema is the schema state derived from bun table models.
type BunModelSchema struct {
BaseDatabase

Tables *orderedmap.OrderedMap[string, Table]
}

func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] {
return ms.Tables
}

// BunTable provides additional table metadata that is only accessible from scanning bun models.
type BunTable struct {
BaseTable

// Model stores the zero interface to the underlying Go struct.
Model interface{}
}

0 comments on commit b2288fc

Please sign in to comment.