Skip to content

Commit

Permalink
fix most tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Nov 9, 2024
1 parent acd2560 commit 8af7563
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 181 deletions.
8 changes: 4 additions & 4 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func newInspector(db *bun.DB, excludeTables ...string) *Inspector {

func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) {
dbSchema := Schema{
BaseTables: make(map[schema.FQN]Table),
Tables: make(map[schema.FQN]sqlschema.Table),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}

Expand All @@ -59,7 +59,7 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) {
return dbSchema, err
}

colDefs := make(map[string]*Column)
colDefs := make(map[string]sqlschema.Column)
uniqueGroups := make(map[string][]string)

for _, c := range columns {
Expand Down Expand Up @@ -102,10 +102,10 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Schema, error) {
}

fqn := schema.FQN{Schema: table.Schema, Table: table.Name}
dbSchema.BaseTables[fqn] = Table{
dbSchema.Tables[fqn] = &Table{
Schema: table.Schema,
Name: table.Name,
ColumnDefinitions: colDefs,
Columns: colDefs,
PrimaryKey: pk,
UniqueConstraints: unique,
}
Expand Down
120 changes: 64 additions & 56 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,76 +93,76 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
defaultSchema := db.Dialect().DefaultSchema()

// Tables come sorted alphabetically by schema and table.
wantTables := map[schema.FQN]sqlschema.BaseTable{
{Schema: "admin", Table: "offices"}: {
wantTables := map[schema.FQN]sqlschema.Table{
{Schema: "admin", Table: "offices"}: &sqlschema.BaseTable{
Schema: "admin",
Name: "offices",
ColumnDefinitions: map[string]sqlschema.BaseColumn{
"office_name": {
Columns: map[string]sqlschema.Column{
"office_name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"publisher_id": {
"publisher_id": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
IsNullable: true,
},
"publisher_name": {
"publisher_name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
IsNullable: true,
},
},
PrimaryKey: &sqlschema.PrimaryKey{Columns: sqlschema.NewColumns("office_name")},
},
{Schema: defaultSchema, Table: "articles"}: {
{Schema: defaultSchema, Table: "articles"}: &sqlschema.BaseTable{
Schema: defaultSchema,
Name: "articles",
ColumnDefinitions: map[string]sqlschema.BaseColumn{
"isbn": {
Columns: map[string]sqlschema.Column{
"isbn": &sqlschema.BaseColumn{
SQLType: "bigint",
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: true,
DefaultValue: "",
},
"editor": {
"editor": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "john doe",
},
"title": {
"title": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "",
},
"locale": {
"locale": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
VarcharLen: 5,
IsNullable: true,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "en-GB",
},
"page_count": {
"page_count": &sqlschema.BaseColumn{
SQLType: "smallint",
IsNullable: false,
IsAutoIncrement: false,
IsIdentity: false,
DefaultValue: "1",
},
"book_count": {
"book_count": &sqlschema.BaseColumn{
SQLType: "integer",
IsNullable: false,
IsAutoIncrement: true,
IsIdentity: false,
DefaultValue: "",
},
"publisher_id": {
"publisher_id": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"author_id": {
"author_id": &sqlschema.BaseColumn{
SQLType: "bigint",
},
},
Expand All @@ -171,21 +171,21 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
{Columns: sqlschema.NewColumns("editor", "title")},
},
},
{Schema: defaultSchema, Table: "authors"}: {
{Schema: defaultSchema, Table: "authors"}: &sqlschema.BaseTable{
Schema: defaultSchema,
Name: "authors",
ColumnDefinitions: map[string]sqlschema.BaseColumn{
"author_id": {
Columns: map[string]sqlschema.Column{
"author_id": &sqlschema.BaseColumn{
SQLType: "bigint",
IsIdentity: true,
},
"first_name": {
"first_name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"last_name": {
"last_name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"email": {
"email": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
},
Expand All @@ -195,31 +195,31 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
{Columns: sqlschema.NewColumns("email")},
},
},
{Schema: defaultSchema, Table: "publisher_to_journalists"}: {
{Schema: defaultSchema, Table: "publisher_to_journalists"}: &sqlschema.BaseTable{
Schema: defaultSchema,
Name: "publisher_to_journalists",
ColumnDefinitions: map[string]sqlschema.BaseColumn{
"publisher_id": {
Columns: map[string]sqlschema.Column{
"publisher_id": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"author_id": {
"author_id": &sqlschema.BaseColumn{
SQLType: "bigint",
},
},
PrimaryKey: &sqlschema.PrimaryKey{Columns: sqlschema.NewColumns("publisher_id", "author_id")},
},
{Schema: defaultSchema, Table: "publishers"}: {
{Schema: defaultSchema, Table: "publishers"}: &sqlschema.BaseTable{
Schema: defaultSchema,
Name: "publishers",
ColumnDefinitions: map[string]sqlschema.BaseColumn{
"publisher_id": {
Columns: map[string]sqlschema.Column{
"publisher_id": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
DefaultValue: "gen_random_uuid()",
},
"publisher_name": {
"publisher_name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
},
"created_at": {
"created_at": &sqlschema.BaseColumn{
SQLType: "timestamp",
DefaultValue: "current_timestamp",
IsNullable: true,
Expand Down Expand Up @@ -260,7 +260,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {

// State.FKs store their database names, which differ from dialect to dialect.
// Because of that we compare FKs and Tables separately.
gotTables := got.(sqlschema.DatabaseSchema).BaseTables
gotTables := got.(sqlschema.DatabaseSchema).Tables
cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, gotTables)

var fks []sqlschema.ForeignKey
Expand Down Expand Up @@ -293,29 +293,36 @@ func mustCreateSchema(tb testing.TB, ctx context.Context, db *bun.DB, schema str

// cmpTables compares table schemas using dialect-specific equivalence checks for column types
// and reports the differences as t.Error().
func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got map[schema.FQN]sqlschema.BaseTable) {
func cmpTables(
tb testing.TB, d sqlschema.InspectorDialect, want, got map[schema.FQN]sqlschema.Table,
) {
tb.Helper()

require.ElementsMatch(tb, tableNames(want), tableNames(got), "different set of tables")

// Now we are guaranteed to have the same tables.
for _, wantTable := range want {
// TODO(dyma): this will be simplified by map[string]Table
var gt sqlschema.BaseTable
var gt sqlschema.Table
for i := range got {
if got[i].Name == wantTable.Name {
if got[i].GetName() == wantTable.GetName() {
gt = got[i]
break
}
}

cmpColumns(tb, d, wantTable.Name, wantTable.ColumnDefinitions, gt.ColumnDefinitions)
cmpConstraints(tb, wantTable, gt)
cmpColumns(tb, d, wantTable.GetName(), wantTable.(*sqlschema.BaseTable).Columns, gt.(*sqlschema.BaseTable).Columns)
cmpConstraints(tb, wantTable.(*sqlschema.BaseTable), gt.(*sqlschema.BaseTable))
}
}

// cmpColumns compares that column definitions on the tables are
func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.BaseColumn) {
func cmpColumns(
tb testing.TB,
d sqlschema.InspectorDialect,
tableName string,
want, got map[string]sqlschema.Column,
) {
tb.Helper()
var errs []string

Expand All @@ -324,7 +331,8 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
errorf := func(format string, args ...interface{}) {
errs = append(errs, fmt.Sprintf("[%s.%s] "+format, append([]interface{}{tableName, colName}, args...)...))
}
gotCol, ok := got[colName]
wantCol := wantCol.(*sqlschema.BaseColumn)
gotCol, ok := got[colName].(*sqlschema.BaseColumn)
if !ok {
missing = append(missing, colName)
continue
Expand Down Expand Up @@ -372,7 +380,7 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
}

// cmpConstraints compares constraints defined on the table with the expected ones.
func cmpConstraints(tb testing.TB, want, got sqlschema.BaseTable) {
func cmpConstraints(tb testing.TB, want, got *sqlschema.BaseTable) {
tb.Helper()

if want.PrimaryKey != nil {
Expand All @@ -392,18 +400,18 @@ func cmpConstraints(tb testing.TB, want, got sqlschema.BaseTable) {
require.ElementsMatch(tb, stripNames(want.UniqueConstraints), stripNames(got.UniqueConstraints), "table %q does not have expected unique constraints (listA=want, listB=got)", want.Name)
}

func tableNames(tables map[schema.FQN]sqlschema.BaseTable) (names []string) {
func tableNames(tables map[schema.FQN]sqlschema.Table) (names []string) {
for fqn := range tables {
names = append(names, fqn.Table)
}
return
}

func formatType(c sqlschema.BaseColumn) string {
if c.VarcharLen == 0 {
return c.SQLType
func formatType(c sqlschema.Column) string {
if c.GetVarcharLen() == 0 {
return c.GetSQLType()
}
return fmt.Sprintf("%s(%d)", c.SQLType, c.VarcharLen)
return fmt.Sprintf("%s(%d)", c.GetSQLType(), c.GetVarcharLen())
}

func TestBunModelInspector_Inspect(t *testing.T) {
Expand All @@ -422,12 +430,12 @@ func TestBunModelInspector_Inspect(t *testing.T) {
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)

want := map[string]sqlschema.BaseColumn{
"id": {
want := map[string]sqlschema.Column{
"id": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
DefaultValue: "random()",
},
"name": {
"name": &sqlschema.BaseColumn{
SQLType: sqltype.VarChar,
DefaultValue: "'John Doe'",
},
Expand All @@ -439,7 +447,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
gotTables := got.(sqlschema.BunModelSchema).ModelTables
require.Len(t, gotTables, 1)
for _, table := range gotTables {
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.ColumnDefinitions)
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.Columns)
return
}
})
Expand All @@ -455,15 +463,15 @@ func TestBunModelInspector_Inspect(t *testing.T) {
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)

want := map[string]sqlschema.BaseColumn{
"id": {
want := map[string]sqlschema.Column{
"id": &sqlschema.BaseColumn{
SQLType: "text",
},
"first_name": {
"first_name": &sqlschema.BaseColumn{
SQLType: "character varying",
VarcharLen: 60,
},
"last_name": {
"last_name": &sqlschema.BaseColumn{
SQLType: "varchar",
VarcharLen: 100,
},
Expand All @@ -475,7 +483,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
gotTables := got.(sqlschema.BunModelSchema).ModelTables
require.Len(t, gotTables, 1)
for _, table := range gotTables {
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.ColumnDefinitions)
cmpColumns(t, dialect.(sqlschema.InspectorDialect), "model", want, table.Columns)
}
})

Expand All @@ -490,7 +498,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)

want := sqlschema.BaseTable{
want := &sqlschema.BaseTable{
Name: "models",
UniqueConstraints: []sqlschema.Unique{
{Columns: sqlschema.NewColumns("id")},
Expand All @@ -504,7 +512,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
gotTables := got.(sqlschema.BunModelSchema).ModelTables
require.Len(t, gotTables, 1)
for _, table := range gotTables {
cmpConstraints(t, want, table.BaseTable)
cmpConstraints(t, want, &table.BaseTable)
return
}
})
Expand Down
Loading

0 comments on commit 8af7563

Please sign in to comment.