diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index fcd1b3931..1830e999d 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -13,6 +13,7 @@ import ( "github.com/uptrace/bun" "github.com/uptrace/bun/dbfixture" + "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" ) @@ -32,6 +33,7 @@ func TestORM(t *testing.T) { {testM2MRelationExcludeColumn}, {testRelationBelongsToSelf}, {testCompositeHasMany}, + {testCompositeM2M}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -441,6 +443,82 @@ func testCompositeHasMany(t *testing.T, db *bun.DB) { require.Equal(t, 2, len(department.Employees)) } +func testCompositeM2M(t *testing.T, db *bun.DB) { + if db.Dialect().Name() == dialect.MSSQL { + t.Skip() + } + + type Item struct { + ID int64 `bun:",pk"` + ShopID int64 `bun:",pk"` + } + + type Order struct { + ID int64 `bun:",pk"` + ShopID int64 `bun:",pk"` + Items []Item `bun:"m2m:orders_to_items,join:Order=Item"` + } + + type OrderToItem struct { + bun.BaseModel `bun:"table:orders_to_items"` + + ShopID int64 `bun:""` + + OrderID int64 `bun:""` + Order *Order `bun:"rel:belongs-to,join:shop_id=shop_id,join:order_id=id"` + ItemID int64 `bun:""` + Item *Item `bun:"rel:belongs-to,join:shop_id=shop_id,join:item_id=id"` + } + + db.RegisterModel((*OrderToItem)(nil)) + mustResetModel(t, ctx, db, (*Order)(nil), (*Item)(nil), (*OrderToItem)(nil)) + + items := []Item{ + {ID: 1, ShopID: 22}, + {ID: 2, ShopID: 22}, + {ID: 3, ShopID: 22}, + } + _, err := db.NewInsert().Model(&items).Exec(ctx) + require.NoError(t, err) + + orders := []Order{ + {ID: 12, ShopID: 22}, + {ID: 13, ShopID: 22}, + } + _, err = db.NewInsert().Model(&orders).Exec(ctx) + require.NoError(t, err) + + orderItems := []OrderToItem{ + {OrderID: 12, ItemID: 1, ShopID: 22}, + {OrderID: 12, ItemID: 2, ShopID: 22}, + {OrderID: 13, ItemID: 3, ShopID: 22}, + } + _, err = db.NewInsert().Model(&orderItems).Exec(ctx) + require.NoError(t, err) + + var ordersOut []Order + + err = db.NewSelect(). + Model(&ordersOut). + Where("id = ?", 12). + Relation("Items"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(ordersOut)) + require.Equal(t, 2, len(ordersOut[0].Items)) + + var ordersOut2 []Order + + err = db.NewSelect(). + Model(&ordersOut2). + Where("id = ?", 13). + Relation("Items"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(ordersOut2)) + require.Equal(t, 1, len(ordersOut2[0].Items)) +} + type Genre struct { ID int `bun:",pk"` Name string diff --git a/model_table_m2m.go b/model_table_m2m.go index 88d8a1268..789a5b0b0 100644 --- a/model_table_m2m.go +++ b/model_table_m2m.go @@ -24,7 +24,7 @@ var _ TableModel = (*m2mModel)(nil) func newM2MModel(j *relationJoin) *m2mModel { baseTable := j.BaseModel.Table() joinModel := j.JoinModel.(*sliceTableModel) - baseValues := baseValues(joinModel, baseTable.PKs) + baseValues := baseValues(joinModel, j.Relation.BaseFields) if len(baseValues) == 0 { return nil } @@ -83,23 +83,17 @@ func (m *m2mModel) Scan(src interface{}) error { column := m.columns[m.scanIndex] m.scanIndex++ - field, ok := m.table.FieldMap[column] - if !ok { + // Base pks must come first. + if m.scanIndex <= len(m.rel.M2MBaseFields) { return m.scanM2MColumn(column, src) } - if err := field.ScanValue(m.strct, src); err != nil { - return err - } - - for _, fk := range m.rel.M2MBaseFields { - if fk.Name == field.Name { - m.structKey = append(m.structKey, field.Value(m.strct).Interface()) - break - } + if field, ok := m.table.FieldMap[column]; ok { + return field.ScanValue(m.strct, src) } - return nil + _, err := m.scanColumn(column, src) + return err } func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { diff --git a/relation_join.go b/relation_join.go index 0ec2aa82d..90f3c999c 100644 --- a/relation_join.go +++ b/relation_join.go @@ -175,10 +175,10 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { q = q.Model(m2mModel) index := j.JoinModel.parentIndex() - baseTable := j.BaseModel.Table() if j.Relation.M2MTable != nil { - fields := append(j.Relation.M2MBaseFields, j.Relation.M2MJoinFields...) + // We only need base pks to park joined models to the base model. + fields := j.Relation.M2MBaseFields b := make([]byte, 0, len(fields)) b = appendColumns(b, j.Relation.M2MTable.SQLAlias, fields) @@ -202,7 +202,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery { join = append(join, col.SQLName...) } join = append(join, ") IN ("...) - join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs) + join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BaseFields) join = append(join, ")"...) q = q.Join(internal.String(join))