diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index b955d1b67..6321f9ad2 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -32,7 +32,7 @@ func newInspector(db *bun.DB, excludeTables ...string) *Inspector { func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) { dbSchema := Schema{ - BaseTables: make(map[schema.FQN]Table), + Tables: make(map[schema.FQN]sqlschema.Table), ForeignKeys: make(map[sqlschema.ForeignKey]string), } @@ -59,7 +59,7 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) { return dbSchema, err } - colDefs := make(map[string]*Column) + colDefs := make(map[string]sqlschema.Column) uniqueGroups := make(map[string][]string) for _, c := range columns { @@ -102,10 +102,10 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) { } fqn := schema.FQN{Schema: table.Schema, Table: table.Name} - dbSchema.BaseTables[fqn] = Table{ + dbSchema.Tables[fqn] = &Table{ Schema: table.Schema, Name: table.Name, - ColumnDefinitions: colDefs, + Columns: colDefs, PrimaryKey: pk, UniqueConstraints: unique, } diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index f9be48d82..9eb77ab21 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -93,51 +93,51 @@ func TestDatabaseInspector_Inspect(t *testing.T) { defaultSchema := db.Dialect().DefaultSchema() // Tables come sorted alphabetically by schema and table. - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: "admin", Table: "offices"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: "admin", Table: "offices"}: &sqlschema.BaseTable{ Schema: "admin", Name: "offices", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "office_name": { + Columns: map[string]sqlschema.Column{ + "office_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "publisher_id": { + "publisher_id": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "publisher_name": { + "publisher_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, }, PrimaryKey: &sqlschema.PrimaryKey{Columns: sqlschema.NewColumns("office_name")}, }, - {Schema: defaultSchema, Table: "articles"}: { + {Schema: defaultSchema, Table: "articles"}: &sqlschema.BaseTable{ Schema: defaultSchema, Name: "articles", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "isbn": { + Columns: map[string]sqlschema.Column{ + "isbn": &sqlschema.BaseColumn{ SQLType: "bigint", IsNullable: false, IsAutoIncrement: false, IsIdentity: true, DefaultValue: "", }, - "editor": { + "editor": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, IsAutoIncrement: false, IsIdentity: false, DefaultValue: "john doe", }, - "title": { + "title": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, IsAutoIncrement: false, IsIdentity: false, DefaultValue: "", }, - "locale": { + "locale": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, VarcharLen: 5, IsNullable: true, @@ -145,24 +145,24 @@ func TestDatabaseInspector_Inspect(t *testing.T) { IsIdentity: false, DefaultValue: "en-GB", }, - "page_count": { + "page_count": &sqlschema.BaseColumn{ SQLType: "smallint", IsNullable: false, IsAutoIncrement: false, IsIdentity: false, DefaultValue: "1", }, - "book_count": { + "book_count": &sqlschema.BaseColumn{ SQLType: "integer", IsNullable: false, IsAutoIncrement: true, IsIdentity: false, DefaultValue: "", }, - "publisher_id": { + "publisher_id": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "author_id": { + "author_id": &sqlschema.BaseColumn{ SQLType: "bigint", }, }, @@ -171,21 +171,21 @@ func TestDatabaseInspector_Inspect(t *testing.T) { {Columns: sqlschema.NewColumns("editor", "title")}, }, }, - {Schema: defaultSchema, Table: "authors"}: { + {Schema: defaultSchema, Table: "authors"}: &sqlschema.BaseTable{ Schema: defaultSchema, Name: "authors", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "author_id": { + Columns: map[string]sqlschema.Column{ + "author_id": &sqlschema.BaseColumn{ SQLType: "bigint", IsIdentity: true, }, - "first_name": { + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "email": { + "email": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, }, @@ -195,31 +195,31 @@ func TestDatabaseInspector_Inspect(t *testing.T) { {Columns: sqlschema.NewColumns("email")}, }, }, - {Schema: defaultSchema, Table: "publisher_to_journalists"}: { + {Schema: defaultSchema, Table: "publisher_to_journalists"}: &sqlschema.BaseTable{ Schema: defaultSchema, Name: "publisher_to_journalists", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "publisher_id": { + Columns: map[string]sqlschema.Column{ + "publisher_id": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "author_id": { + "author_id": &sqlschema.BaseColumn{ SQLType: "bigint", }, }, PrimaryKey: &sqlschema.PrimaryKey{Columns: sqlschema.NewColumns("publisher_id", "author_id")}, }, - {Schema: defaultSchema, Table: "publishers"}: { + {Schema: defaultSchema, Table: "publishers"}: &sqlschema.BaseTable{ Schema: defaultSchema, Name: "publishers", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "publisher_id": { + Columns: map[string]sqlschema.Column{ + "publisher_id": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, DefaultValue: "gen_random_uuid()", }, - "publisher_name": { + "publisher_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, }, - "created_at": { + "created_at": &sqlschema.BaseColumn{ SQLType: "timestamp", DefaultValue: "current_timestamp", IsNullable: true, @@ -260,7 +260,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { // State.FKs store their database names, which differ from dialect to dialect. // Because of that we compare FKs and Tables separately. - gotTables := got.(sqlschema.DatabaseSchema).BaseTables + gotTables := got.(sqlschema.DatabaseSchema).Tables cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, gotTables) var fks []sqlschema.ForeignKey @@ -293,7 +293,9 @@ func mustCreateSchema(tb testing.TB, ctx context.Context, db *bun.DB, schema str // cmpTables compares table schemas using dialect-specific equivalence checks for column types // and reports the differences as t.Error(). -func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got map[schema.FQN]sqlschema.BaseTable) { +func cmpTables( + tb testing.TB, d sqlschema.InspectorDialect, want, got map[schema.FQN]sqlschema.Table, +) { tb.Helper() require.ElementsMatch(tb, tableNames(want), tableNames(got), "different set of tables") @@ -301,21 +303,26 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got map[schema // Now we are guaranteed to have the same tables. for _, wantTable := range want { // TODO(dyma): this will be simplified by map[string]Table - var gt sqlschema.BaseTable + var gt sqlschema.Table for i := range got { - if got[i].Name == wantTable.Name { + if got[i].GetName() == wantTable.GetName() { gt = got[i] break } } - cmpColumns(tb, d, wantTable.Name, wantTable.ColumnDefinitions, gt.ColumnDefinitions) - cmpConstraints(tb, wantTable, gt) + cmpColumns(tb, d, wantTable.GetName(), wantTable.(*sqlschema.BaseTable).Columns, gt.(*sqlschema.BaseTable).Columns) + cmpConstraints(tb, wantTable.(*sqlschema.BaseTable), gt.(*sqlschema.BaseTable)) } } // cmpColumns compares that column definitions on the tables are -func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.BaseColumn) { +func cmpColumns( + tb testing.TB, + d sqlschema.InspectorDialect, + tableName string, + want, got map[string]sqlschema.Column, +) { tb.Helper() var errs []string @@ -324,7 +331,8 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w errorf := func(format string, args ...interface{}) { errs = append(errs, fmt.Sprintf("[%s.%s] "+format, append([]interface{}{tableName, colName}, args...)...)) } - gotCol, ok := got[colName] + wantCol := wantCol.(*sqlschema.BaseColumn) + gotCol, ok := got[colName].(*sqlschema.BaseColumn) if !ok { missing = append(missing, colName) continue @@ -372,7 +380,7 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w } // cmpConstraints compares constraints defined on the table with the expected ones. -func cmpConstraints(tb testing.TB, want, got sqlschema.BaseTable) { +func cmpConstraints(tb testing.TB, want, got *sqlschema.BaseTable) { tb.Helper() if want.PrimaryKey != nil { @@ -392,18 +400,18 @@ func cmpConstraints(tb testing.TB, want, got sqlschema.BaseTable) { require.ElementsMatch(tb, stripNames(want.UniqueConstraints), stripNames(got.UniqueConstraints), "table %q does not have expected unique constraints (listA=want, listB=got)", want.Name) } -func tableNames(tables map[schema.FQN]sqlschema.BaseTable) (names []string) { +func tableNames(tables map[schema.FQN]sqlschema.Table) (names []string) { for fqn := range tables { names = append(names, fqn.Table) } return } -func formatType(c sqlschema.BaseColumn) string { - if c.VarcharLen == 0 { - return c.SQLType +func formatType(c sqlschema.Column) string { + if c.GetVarcharLen() == 0 { + return c.GetSQLType() } - return fmt.Sprintf("%s(%d)", c.SQLType, c.VarcharLen) + return fmt.Sprintf("%s(%d)", c.GetSQLType(), c.GetVarcharLen()) } func TestBunModelInspector_Inspect(t *testing.T) { @@ -422,12 +430,12 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables.Register((*Model)(nil)) inspector := sqlschema.NewBunModelInspector(tables) - want := map[string]sqlschema.BaseColumn{ - "id": { + want := map[string]sqlschema.Column{ + "id": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, DefaultValue: "random()", }, - "name": { + "name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, DefaultValue: "'John Doe'", }, @@ -439,7 +447,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { gotTables := got.(sqlschema.BunModelSchema).ModelTables require.Len(t, gotTables, 1) for _, table := range gotTables { - cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.ColumnDefinitions) + cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.Columns) return } }) @@ -455,15 +463,15 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables.Register((*Model)(nil)) inspector := sqlschema.NewBunModelInspector(tables) - want := map[string]sqlschema.BaseColumn{ - "id": { + want := map[string]sqlschema.Column{ + "id": &sqlschema.BaseColumn{ SQLType: "text", }, - "first_name": { + "first_name": &sqlschema.BaseColumn{ SQLType: "character varying", VarcharLen: 60, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: "varchar", VarcharLen: 100, }, @@ -475,7 +483,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { gotTables := got.(sqlschema.BunModelSchema).ModelTables require.Len(t, gotTables, 1) for _, table := range gotTables { - cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.ColumnDefinitions) + cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.Columns) } }) @@ -490,7 +498,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables.Register((*Model)(nil)) inspector := sqlschema.NewBunModelInspector(tables) - want := sqlschema.BaseTable{ + want := &sqlschema.BaseTable{ Name: "models", UniqueConstraints: []sqlschema.Unique{ {Columns: sqlschema.NewColumns("id")}, @@ -504,7 +512,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { gotTables := got.(sqlschema.BunModelSchema).ModelTables require.Len(t, gotTables, 1) for _, table := range gotTables { - cmpConstraints(t, want, table.BaseTable) + cmpConstraints(t, want, &table.BaseTable) return } }) diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 592178d5e..3e7d689f6 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -369,7 +369,7 @@ func testRenameTable(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - tables := state.BaseTables + tables := state.Tables require.Len(t, tables, 1) require.Contains(t, tables, schema.FQN{Schema: db.Dialect().DefaultSchema(), Table: "changed"}) } @@ -398,7 +398,7 @@ func testCreateDropTable(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - tables := state.BaseTables + tables := state.Tables require.Len(t, tables, 1) require.Contains(t, tables, schema.FQN{Schema: db.Dialect().DefaultSchema(), Table: "createme"}) } @@ -524,11 +524,11 @@ func testRenamedColumns(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - require.Len(t, state.BaseTables, 2) + require.Len(t, state.Tables, 2) - var renamed, model2 sqlschema.BaseTable - for _, tbl := range state.BaseTables { - switch tbl.Name { + var renamed, model2 sqlschema.Table + for _, tbl := range state.Tables { + switch tbl.GetName() { case "renamed": renamed = tbl case "models": @@ -536,9 +536,9 @@ func testRenamedColumns(t *testing.T, db *bun.DB) { } } - require.Contains(t, renamed.ColumnDefinitions, "count") - require.Contains(t, model2.ColumnDefinitions, "second_column") - require.Contains(t, model2.ColumnDefinitions, "do_not_rename") + require.Contains(t, renamed.GetColumns(), "count") + require.Contains(t, model2.GetColumns(), "second_column") + require.Contains(t, model2.GetColumns(), "do_not_rename") } // testChangeColumnType_AutoCast checks type changes which can be type-casted automatically, @@ -568,35 +568,35 @@ func testChangeColumnType_AutoCast(t *testing.T, db *bun.DB) { // ManyValues []string `bun:",array"` // did not change } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "change_me_own_type"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "change_me_own_type"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "change_me_own_type", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "bigger_int": { + Columns: map[string]sqlschema.Column{ + "bigger_int": &sqlschema.BaseColumn{ SQLType: "bigint", IsIdentity: true, }, - "ts": { + "ts": &sqlschema.BaseColumn{ SQLType: "timestamp", // FIXME(dyma): convert "timestamp with time zone" to sqltype.Timestamp DefaultValue: "current_timestamp", // FIXME(dyma): Convert driver-specific value to common "expressions" (e.g. CURRENT_TIMESTAMP == current_timestamp) OR lowercase all types. IsNullable: true, }, - "default_expr": { + "default_expr": &sqlschema.BaseColumn{ SQLType: "varchar", IsNullable: true, DefaultValue: "random()", }, - "empty_default": { + "empty_default": &sqlschema.BaseColumn{ SQLType: "varchar", IsNullable: true, DefaultValue: "", // NOT "''" }, - "not_null": { + "not_null": &sqlschema.BaseColumn{ SQLType: "varchar", IsNullable: false, }, - "type_override": { + "type_override": &sqlschema.BaseColumn{ SQLType: "varchar", IsNullable: true, VarcharLen: 200, @@ -619,7 +619,7 @@ func testChangeColumnType_AutoCast(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } func testIdentity(t *testing.T, db *bun.DB) { @@ -635,16 +635,16 @@ func testIdentity(t *testing.T, db *bun.DB) { B int64 `bun:",notnull,identity"` } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "bourne_identity"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "bourne_identity"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "bourne_identity", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "a": { + Columns: map[string]sqlschema.Column{ + "a": &sqlschema.BaseColumn{ SQLType: sqltype.BigInt, IsIdentity: false, // <- drop IDENTITY }, - "b": { + "b": &sqlschema.BaseColumn{ SQLType: sqltype.BigInt, IsIdentity: true, // <- add IDENTITY }, @@ -662,7 +662,7 @@ func testIdentity(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } func testAddDropColumn(t *testing.T, db *bun.DB) { @@ -678,16 +678,16 @@ func testAddDropColumn(t *testing.T, db *bun.DB) { AddMe bool `bun:"addme"` } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "column_madness"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "column_madness"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "column_madness", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "do_not_touch": { + Columns: map[string]sqlschema.Column{ + "do_not_touch": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "addme": { + "addme": &sqlschema.BaseColumn{ SQLType: sqltype.Boolean, IsNullable: true, }, @@ -705,7 +705,7 @@ func testAddDropColumn(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } func testUnique(t *testing.T, db *bun.DB) { @@ -731,36 +731,36 @@ func testUnique(t *testing.T, db *bun.DB) { PetBreed string `bun:"pet_breed"` // shrink "pet" unique group } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "uniqlo_stores"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "uniqlo_stores"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "uniqlo_stores", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "first_name": { + Columns: map[string]sqlschema.Column{ + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "middle_name": { + "middle_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "birthday": { + "birthday": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "email": { + "email": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "pet_name": { + "pet_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "pet_breed": { + "pet_breed": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, @@ -784,7 +784,7 @@ func testUnique(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } func testUniqueRenamedTable(t *testing.T, db *bun.DB) { @@ -809,28 +809,28 @@ func testUniqueRenamedTable(t *testing.T, db *bun.DB) { PetBreed string `bun:"pet_breed,unique"` } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "after"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "after"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "after", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "first_name": { + Columns: map[string]sqlschema.Column{ + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "birthday": { + "birthday": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "pet_name": { + "pet_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "pet_breed": { + "pet_breed": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, @@ -854,7 +854,7 @@ func testUniqueRenamedTable(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } func testUpdatePrimaryKeys(t *testing.T, db *bun.DB) { @@ -904,50 +904,50 @@ func testUpdatePrimaryKeys(t *testing.T, db *bun.DB) { LastName string `bun:"last_name,pk"` } - wantTables := map[schema.FQN]sqlschema.BaseTable{ - {Schema: db.Dialect().DefaultSchema(), Table: "drop_your_pks"}: { + wantTables := map[schema.FQN]sqlschema.Table{ + {Schema: db.Dialect().DefaultSchema(), Table: "drop_your_pks"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "drop_your_pks", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "first_name": { + Columns: map[string]sqlschema.Column{ + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, }, }, }, - {Schema: db.Dialect().DefaultSchema(), Table: "add_new_pk"}: { + {Schema: db.Dialect().DefaultSchema(), Table: "add_new_pk"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "add_new_pk", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "new_id": { + Columns: map[string]sqlschema.Column{ + "new_id": &sqlschema.BaseColumn{ SQLType: sqltype.BigInt, IsNullable: false, IsIdentity: true, }, - "first_name": { + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: true, }, }, PrimaryKey: &sqlschema.PrimaryKey{Columns: sqlschema.NewColumns("new_id")}, }, - {Schema: db.Dialect().DefaultSchema(), Table: "change_pk"}: { + {Schema: db.Dialect().DefaultSchema(), Table: "change_pk"}: &sqlschema.BaseTable{ Schema: db.Dialect().DefaultSchema(), Name: "change_pk", - ColumnDefinitions: map[string]sqlschema.BaseColumn{ - "first_name": { + Columns: map[string]sqlschema.Column{ + "first_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, }, - "last_name": { + "last_name": &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, }, @@ -974,5 +974,5 @@ func testUpdatePrimaryKeys(t *testing.T, db *bun.DB) { // Assert state := inspect(ctx) - cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.BaseTables) + cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) } diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 1d074f7ef..093c940c2 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -1618,7 +1618,7 @@ func TestAlterTable(t *testing.T) { {name: "add column with default value", operation: &migrate.AddColumnOp{ FQN: fqn, Column: "language", - ColDef: sqlschema.BaseColumn{ + ColDef: &sqlschema.BaseColumn{ SQLType: "varchar", VarcharLen: 20, IsNullable: false, @@ -1628,7 +1628,7 @@ func TestAlterTable(t *testing.T) { {name: "add column with identity", operation: &migrate.AddColumnOp{ FQN: fqn, Column: "n", - ColDef: sqlschema.BaseColumn{ + ColDef: &sqlschema.BaseColumn{ SQLType: sqltype.BigInt, IsNullable: false, IsIdentity: true, @@ -1637,7 +1637,7 @@ func TestAlterTable(t *testing.T) { {name: "drop column", operation: &migrate.DropColumnOp{ FQN: fqn, Column: "director", - ColDef: sqlschema.BaseColumn{ + ColDef: &sqlschema.BaseColumn{ SQLType: sqltype.VarChar, IsNullable: false, }, @@ -1659,50 +1659,50 @@ func TestAlterTable(t *testing.T) { {name: "change column type int to bigint", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "budget", - From: sqlschema.BaseColumn{SQLType: sqltype.Integer}, - To: sqlschema.BaseColumn{SQLType: sqltype.BigInt}, + From: &sqlschema.BaseColumn{SQLType: sqltype.Integer}, + To: &sqlschema.BaseColumn{SQLType: sqltype.BigInt}, }}, {name: "add default", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "budget", - From: sqlschema.BaseColumn{DefaultValue: ""}, - To: sqlschema.BaseColumn{DefaultValue: "100"}, + From: &sqlschema.BaseColumn{DefaultValue: ""}, + To: &sqlschema.BaseColumn{DefaultValue: "100"}, }}, {name: "drop default", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "budget", - From: sqlschema.BaseColumn{DefaultValue: "100"}, - To: sqlschema.BaseColumn{DefaultValue: ""}, + From: &sqlschema.BaseColumn{DefaultValue: "100"}, + To: &sqlschema.BaseColumn{DefaultValue: ""}, }}, {name: "make nullable", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "director", - From: sqlschema.BaseColumn{IsNullable: false}, - To: sqlschema.BaseColumn{IsNullable: true}, + From: &sqlschema.BaseColumn{IsNullable: false}, + To: &sqlschema.BaseColumn{IsNullable: true}, }}, {name: "add notnull", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "budget", - From: sqlschema.BaseColumn{IsNullable: true}, - To: sqlschema.BaseColumn{IsNullable: false}, + From: &sqlschema.BaseColumn{IsNullable: true}, + To: &sqlschema.BaseColumn{IsNullable: false}, }}, {name: "increase varchar length", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "language", - From: sqlschema.BaseColumn{SQLType: "varchar", VarcharLen: 20}, - To: sqlschema.BaseColumn{SQLType: "varchar", VarcharLen: 255}, + From: &sqlschema.BaseColumn{SQLType: "varchar", VarcharLen: 20}, + To: &sqlschema.BaseColumn{SQLType: "varchar", VarcharLen: 255}, }}, {name: "add identity", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "id", - From: sqlschema.BaseColumn{IsIdentity: false}, - To: sqlschema.BaseColumn{IsIdentity: true}, + From: &sqlschema.BaseColumn{IsIdentity: false}, + To: &sqlschema.BaseColumn{IsIdentity: true}, }}, {name: "drop identity", operation: &migrate.ChangeColumnTypeOp{ FQN: fqn, Column: "id", - From: sqlschema.BaseColumn{IsIdentity: true}, - To: sqlschema.BaseColumn{IsIdentity: false}, + From: &sqlschema.BaseColumn{IsIdentity: true}, + To: &sqlschema.BaseColumn{IsIdentity: false}, }}, {name: "add primary key", operation: &migrate.AddPrimaryKeyOp{ FQN: fqn, diff --git a/migrate/diff.go b/migrate/diff.go index 4d6a177c6..e1fdd160e 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -58,7 +58,7 @@ RenameCreate: // If wantTable does not exist in the database and was not renamed // then we need to create this table in the database. - additional := wantTable.(sqlschema.BunTable) + additional := wantTable.(*sqlschema.BunTable) d.changes.Add(&CreateTableOp{ FQN: wantTable.GetFQN(), Model: additional.Model, diff --git a/migrate/sqlschema/column.go b/migrate/sqlschema/column.go index 5a8b70483..95d9a3efc 100644 --- a/migrate/sqlschema/column.go +++ b/migrate/sqlschema/column.go @@ -19,7 +19,8 @@ type Column interface { var _ Column = (*BaseColumn)(nil) -// BaseColumn stores attributes of a database column. +// BaseColumn is a base column definition that stores various attributes of a column. +// It MUST only be used by dialects to implement the Column interface. type BaseColumn struct { Name string SQLType string diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index d8b882182..24c52fcb2 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -66,7 +66,7 @@ func NewBunModelInspector(tables *schema.Tables) *BunModelInspector { type BunModelSchema struct { DatabaseSchema - ModelTables map[schema.FQN]BunTable + ModelTables map[schema.FQN]*BunTable } func (ms BunModelSchema) GetTables() []Table { @@ -90,10 +90,10 @@ func (bmi *BunModelInspector) Inspect(ctx context.Context) (Schema, error) { DatabaseSchema: DatabaseSchema{ ForeignKeys: make(map[ForeignKey]string), }, - ModelTables: make(map[schema.FQN]BunTable), + ModelTables: make(map[schema.FQN]*BunTable), } for _, t := range bmi.tables.All() { - columns := make(map[string]*BaseColumn) + columns := make(map[string]Column) for _, f := range t.Fields { sqlType, length, err := parseLen(f.CreateTableSQLType) @@ -140,11 +140,11 @@ func (bmi *BunModelInspector) Inspect(ctx context.Context) (Schema, error) { } fqn := schema.FQN{Schema: t.Schema, Table: t.Name} - state.ModelTables[fqn] = BunTable{ + state.ModelTables[fqn] = &BunTable{ BaseTable: BaseTable{ Schema: t.Schema, Name: t.Name, - ColumnDefinitions: columns, + Columns: columns, UniqueConstraints: unique, PrimaryKey: pk, }, diff --git a/migrate/sqlschema/schema.go b/migrate/sqlschema/schema.go index f823d4cd8..3d72a0a37 100644 --- a/migrate/sqlschema/schema.go +++ b/migrate/sqlschema/schema.go @@ -11,7 +11,7 @@ import ( // Dialects which support schema inspection may return it directly from Inspect() // or embed it in their custom schema structs. type DatabaseSchema struct { - BaseTables map[schema.FQN]BaseTable + Tables map[schema.FQN]Table ForeignKeys map[ForeignKey]string } @@ -113,8 +113,8 @@ type ColumnReference struct { func (ds DatabaseSchema) GetTables() []Table { var tables []Table - for i := range ds.BaseTables { - tables = append(tables, ds.BaseTables[i]) + for i := range ds.Tables { + tables = append(tables, ds.Tables[i]) } return tables } @@ -122,28 +122,3 @@ func (ds DatabaseSchema) GetTables() []Table { func (ds DatabaseSchema) GetForeignKeys() map[ForeignKey]string { return ds.ForeignKeys } - -func (td BaseTable) GetSchema() string { - return td.Schema -} -func (td BaseTable) GetName() string { - return td.Name -} -func (td BaseTable) GetColumns() []Column { - var columns []Column - // FIXME: columns will be returned in a random order - for colName := range td.ColumnDefinitions { - columns = append(columns, td.ColumnDefinitions[colName]) - } - return columns -} -func (td BaseTable) GetPrimaryKey() *PrimaryKey { - return td.PrimaryKey -} -func (td BaseTable) GetUniqueConstraints() []Unique { - return td.UniqueConstraints -} - -func (t BaseTable) GetFQN() schema.FQN { - return schema.FQN{Schema: t.Schema, Table: t.Name} -} diff --git a/migrate/sqlschema/table.go b/migrate/sqlschema/table.go index 50e479efb..54877f60b 100644 --- a/migrate/sqlschema/table.go +++ b/migrate/sqlschema/table.go @@ -13,12 +13,15 @@ type Table interface { var _ Table = (*BaseTable)(nil) +// BaseTable is a base table definition. +// It MUST only be used by dialects to implement the Table interface. type BaseTable struct { Schema string Name string // ColumnDefinitions map each column name to the column definition. - ColumnDefinitions map[string]*BaseColumn + // TODO: this must be an ordered map so the order of columns is preserved + Columns map[string]Column // PrimaryKey holds the primary key definition. // A nil value means that no primary key is defined for the table. @@ -33,3 +36,32 @@ type PrimaryKey struct { Name string Columns Columns } + +func (td *BaseTable) GetSchema() string { + return td.Schema +} + +func (td *BaseTable) GetName() string { + return td.Name +} + +func (td *BaseTable) GetColumns() []Column { + var columns []Column + // FIXME: columns will be returned in a random order + for colName := range td.Columns { + columns = append(columns, td.Columns[colName]) + } + return columns +} + +func (td *BaseTable) GetPrimaryKey() *PrimaryKey { + return td.PrimaryKey +} + +func (td *BaseTable) GetUniqueConstraints() []Unique { + return td.UniqueConstraints +} + +func (t *BaseTable) GetFQN() schema.FQN { + return schema.FQN{Schema: t.Schema, Table: t.Name} +}