Skip to content

Commit

Permalink
Fix M2M relations for models with composite keys (#996)
Browse files Browse the repository at this point in the history
* fix(m2m join): composite keys
  • Loading branch information
ygabuev authored Oct 26, 2024
1 parent e2cd9f0 commit d99b767
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 16 deletions.
78 changes: 78 additions & 0 deletions internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/uptrace/bun"
"github.com/uptrace/bun/dbfixture"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
)

Expand All @@ -32,6 +33,7 @@ func TestORM(t *testing.T) {
{testM2MRelationExcludeColumn},
{testRelationBelongsToSelf},
{testCompositeHasMany},
{testCompositeM2M},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -441,6 +443,82 @@ func testCompositeHasMany(t *testing.T, db *bun.DB) {
require.Equal(t, 2, len(department.Employees))
}

func testCompositeM2M(t *testing.T, db *bun.DB) {
if db.Dialect().Name() == dialect.MSSQL {
t.Skip()
}

type Item struct {
ID int64 `bun:",pk"`
ShopID int64 `bun:",pk"`
}

type Order struct {
ID int64 `bun:",pk"`
ShopID int64 `bun:",pk"`
Items []Item `bun:"m2m:orders_to_items,join:Order=Item"`
}

type OrderToItem struct {
bun.BaseModel `bun:"table:orders_to_items"`

ShopID int64 `bun:""`

OrderID int64 `bun:""`
Order *Order `bun:"rel:belongs-to,join:shop_id=shop_id,join:order_id=id"`
ItemID int64 `bun:""`
Item *Item `bun:"rel:belongs-to,join:shop_id=shop_id,join:item_id=id"`
}

db.RegisterModel((*OrderToItem)(nil))
mustResetModel(t, ctx, db, (*Order)(nil), (*Item)(nil), (*OrderToItem)(nil))

items := []Item{
{ID: 1, ShopID: 22},
{ID: 2, ShopID: 22},
{ID: 3, ShopID: 22},
}
_, err := db.NewInsert().Model(&items).Exec(ctx)
require.NoError(t, err)

orders := []Order{
{ID: 12, ShopID: 22},
{ID: 13, ShopID: 22},
}
_, err = db.NewInsert().Model(&orders).Exec(ctx)
require.NoError(t, err)

orderItems := []OrderToItem{
{OrderID: 12, ItemID: 1, ShopID: 22},
{OrderID: 12, ItemID: 2, ShopID: 22},
{OrderID: 13, ItemID: 3, ShopID: 22},
}
_, err = db.NewInsert().Model(&orderItems).Exec(ctx)
require.NoError(t, err)

var ordersOut []Order

err = db.NewSelect().
Model(&ordersOut).
Where("id = ?", 12).
Relation("Items").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 1, len(ordersOut))
require.Equal(t, 2, len(ordersOut[0].Items))

var ordersOut2 []Order

err = db.NewSelect().
Model(&ordersOut2).
Where("id = ?", 13).
Relation("Items").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 1, len(ordersOut2))
require.Equal(t, 1, len(ordersOut2[0].Items))
}

type Genre struct {
ID int `bun:",pk"`
Name string
Expand Down
20 changes: 7 additions & 13 deletions model_table_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ var _ TableModel = (*m2mModel)(nil)
func newM2MModel(j *relationJoin) *m2mModel {
baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, baseTable.PKs)
baseValues := baseValues(joinModel, j.Relation.BaseFields)
if len(baseValues) == 0 {
return nil
}
Expand Down Expand Up @@ -83,23 +83,17 @@ func (m *m2mModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex]
m.scanIndex++

field, ok := m.table.FieldMap[column]
if !ok {
// Base pks must come first.
if m.scanIndex <= len(m.rel.M2MBaseFields) {
return m.scanM2MColumn(column, src)
}

if err := field.ScanValue(m.strct, src); err != nil {
return err
}

for _, fk := range m.rel.M2MBaseFields {
if fk.Name == field.Name {
m.structKey = append(m.structKey, field.Value(m.strct).Interface())
break
}
if field, ok := m.table.FieldMap[column]; ok {
return field.ScanValue(m.strct, src)
}

return nil
_, err := m.scanColumn(column, src)
return err
}

func (m *m2mModel) scanM2MColumn(column string, src interface{}) error {
Expand Down
6 changes: 3 additions & 3 deletions relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
q = q.Model(m2mModel)

index := j.JoinModel.parentIndex()
baseTable := j.BaseModel.Table()

if j.Relation.M2MTable != nil {
fields := append(j.Relation.M2MBaseFields, j.Relation.M2MJoinFields...)
// We only need base pks to park joined models to the base model.
fields := j.Relation.M2MBaseFields

b := make([]byte, 0, len(fields))
b = appendColumns(b, j.Relation.M2MTable.SQLAlias, fields)
Expand All @@ -202,7 +202,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
join = append(join, col.SQLName...)
}
join = append(join, ") IN ("...)
join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs)
join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BaseFields)
join = append(join, ")"...)
q = q.Join(internal.String(join))

Expand Down

0 comments on commit d99b767

Please sign in to comment.