From 85da372f5e2e265b8a072f9e542a4e035b666271 Mon Sep 17 00:00:00 2001 From: thecampagnards Date: Wed, 27 Nov 2024 10:10:35 +0000 Subject: [PATCH] fix: m2m relation with driver.Valuer --- internal/dbtest/db_test.go | 55 ++++++++++++++++++++++++++++++++++++++ model_table_has_many.go | 7 +++++ 2 files changed, 62 insertions(+) diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index ddc9d70a5..e609c7140 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -279,6 +279,7 @@ func TestDB(t *testing.T) { {testWithForeignKeys}, {testWithForeignKeysHasMany}, {testWithPointerForeignKeysHasMany}, + {testWithPointerForeignKeysHasManySQLNull}, {testInterfaceAny}, {testInterfaceJSON}, {testScanRawMessage}, @@ -1172,6 +1173,60 @@ func testWithPointerForeignKeysHasMany(t *testing.T, db *bun.DB) { require.Len(t, deck.Users, 2) } +func testWithPointerForeignKeysHasManySQLNull(t *testing.T, db *bun.DB) { + type User struct { + ID *int `bun:",pk"` + DeckID sql.NullInt64 + Name string + } + type Deck struct { + ID int64 `bun:",pk"` + Users []*User `bun:"rel:has-many,join:id=deck_id"` + } + + if db.Dialect().Name() == dialect.SQLite { + _, err := db.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + } + + for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} { + _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) + require.NoError(t, err) + } + + mustResetModel(t, ctx, db, (*User)(nil)) + _, err := db.NewCreateTable(). + Model((*Deck)(nil)). + IfNotExists(). + WithForeignKeys(). + Exec(ctx) + require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) + + deckID := int64(1) + deck := Deck{ID: deckID} + _, err = db.NewInsert().Model(&deck).Exec(ctx) + require.NoError(t, err) + + userID1 := 1 + userID2 := 2 + users := []*User{ + {ID: &userID1, DeckID: sql.NullInt64{Int64: deckID, Valid: true}, Name: "user 1"}, + {ID: &userID2, DeckID: sql.NullInt64{Int64: deckID, Valid: true}, Name: "user 2"}, + } + + res, err := db.NewInsert().Model(&users).Exec(ctx) + require.NoError(t, err) + + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(2), affected) + + err = db.NewSelect().Model(&deck).Relation("Users").Scan(ctx) + require.NoError(t, err) + require.Len(t, deck.Users, 2) +} + func testInterfaceAny(t *testing.T, db *bun.DB) { switch db.Dialect().Name() { case dialect.MySQL: diff --git a/model_table_has_many.go b/model_table_has_many.go index cd721a1b2..f367d2c96 100644 --- a/model_table_has_many.go +++ b/model_table_has_many.go @@ -3,6 +3,7 @@ package bun import ( "context" "database/sql" + "database/sql/driver" "fmt" "reflect" @@ -152,6 +153,12 @@ func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) [] // The value is then used as a map key. func indirectFieldValue(field reflect.Value) interface{} { if field.Kind() != reflect.Ptr { + i := field.Interface() + if valuer, ok := i.(driver.Valuer); ok { + if v, err := valuer.Value(); err == nil { + return v + } + } return field.Interface() } if field.IsNil() {