Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(relations): support macro ?TableAlias in relation callback #585

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions internal/dbtest/relation_join_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package dbtest_test

import (
"context"
"database/sql"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
"github.com/uptrace/bun/extra/bundebug"
"testing"
)

type TestRelProfile struct {
ID int64 `bun:",pk,autoincrement"`
Lang string
UserID int64
}

type TestRelUser struct {
ID int64 `bun:",pk,autoincrement"`
Name string
Profile *TestRelProfile `bun:"rel:has-one,join:id=user_id"`
Disks []TestRelDisk `bun:"rel:has-many,join:id=user_id"`
}

type TestRelDisk struct {
ID int64 `bun:",pk,autoincrement"`
Title string
UserID int64
User *TestRelUser `bun:"rel:belongs-to,join:user_id=id"`
}

func TestRelationJoin(t *testing.T) {

ctx := context.Background()

sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
if err != nil {
panic(err)
}

db := bun.NewDB(sqldb, sqlitedialect.New())
defer db.Close()

db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))

// Create schema

models := []interface{}{
(*TestRelUser)(nil),
(*TestRelProfile)(nil),
(*TestRelDisk)(nil),
}
for _, model := range models {
_, err = db.NewCreateTable().Model(model).Exec(ctx)
require.NoError(t, err)
}

expectedUsers := []*TestRelUser{
{ID: 1, Name: "user 1"},
{ID: 2, Name: "user 2"},
}

_, err = db.NewInsert().Model(&expectedUsers).Exec(ctx)
require.NoError(t, err)

expectedProfiles := []*TestRelProfile{
{ID: 1, Lang: "en", UserID: 1},
{ID: 2, Lang: "ru", UserID: 2},
}

_, err = db.NewInsert().Model(&expectedProfiles).Exec(ctx)
require.NoError(t, err)

expectedDisks := []*TestRelDisk{
{ID: 1, Title: "Nirvana", UserID: 1},
{ID: 2, Title: "Linkin Park", UserID: 2},
}

_, err = db.NewInsert().Model(&expectedDisks).Exec(ctx)
require.NoError(t, err)

// test Has One relation

var users []TestRelUser
err = db.NewSelect().
Model(&users).
Relation("Profile").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, len(expectedUsers), len(users))

// test Has One relation with filter

users = []TestRelUser{}
err = db.NewSelect().
Model(&users).
Relation("Profile", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Where("?TableAlias.lang = ?", "ru")
}).
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 1, len(users))

// test Has One relation with join on

users = []TestRelUser{}
err = db.NewSelect().
Model(&users).
Relation("Profile", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.JoinOn("?TableAlias.lang = ?", "ru")
}).
OrderExpr("?TableAlias.ID").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 2, len(users))
require.Nil(t, users[0].Profile)
require.NotNil(t, users[1].Profile)
require.Equal(t, int64(2), users[1].Profile.ID)

// test Has Many relation

users = []TestRelUser{}
err = db.NewSelect().
Model(&users).
Relation("Disks", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Where("?TableAlias.title = ?", "Linkin Park")
}).
Order("id").
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 0, len(users[0].Disks))
require.Equal(t, 1, len(users[1].Disks))
require.Equal(t, "Linkin Park", users[1].Disks[0].Title)

// test Belongs To relation

var disks []TestRelDisk
err = db.NewSelect().
Model(&disks).
Relation("User", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.Where("?TableAlias.name = ?", "user 2")
}).
Scan(ctx)
require.NoError(t, err)
require.Equal(t, 1, len(disks))
require.Equal(t, "Linkin Park", disks[0].Title)
}
83 changes: 81 additions & 2 deletions relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package bun

import (
"context"
"reflect"

"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
"reflect"
)

type relationJoin struct {
Expand All @@ -16,6 +15,24 @@ type relationJoin struct {

apply func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
joinOn []schema.QueryWithSep
}

type tableAliasArg struct {
j *relationJoin
alias string
}

func (a *tableAliasArg) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) {
if name != "TableAlias" {
return nil, false
}

if a.alias == "" {
return a.j.appendAlias(fmter, b), true
}

return fmter.AppendIdent(b, a.alias), true
}

func (j *relationJoin) applyTo(q *SelectQuery) {
Expand All @@ -25,13 +42,50 @@ func (j *relationJoin) applyTo(q *SelectQuery) {

var table *schema.Table
var columns []schema.QueryWithArgs
var joins []joinQuery

// Save state.
table, q.table = q.table, j.JoinModel.Table()
columns, q.columns = q.columns, nil

oldWhere := q.where

if j.Relation.Type == schema.HasOneRelation || j.Relation.Type == schema.BelongsToRelation {
joins, q.joins = q.joins, []joinQuery{{}}
}

q = j.apply(q)

var newWhere []schema.QueryWithSep

var alias string

if j.Relation.Type == schema.HasManyRelation || j.Relation.Type == schema.ManyToManyRelation {
alias = j.Relation.JoinTable.Alias
}

fmter := q.db.fmter.WithArg(&tableAliasArg{j: j, alias: alias})

for i, w := range q.where {
if i >= len(oldWhere) {
w.Query = string(fmter.AppendQuery([]byte{}, w.Query))
}
newWhere = append(newWhere, w)
}

q.where = newWhere

if j.Relation.Type == schema.HasOneRelation || j.Relation.Type == schema.BelongsToRelation {
var joinOn []schema.QueryWithSep

for _, on := range q.joins[0].on {
on.Query = string(fmter.AppendQuery([]byte{}, on.Query))
joinOn = append(joinOn, on)
}

j.joinOn, q.joins = joinOn, joins
}

// Restore state.
q.table = table
j.columns, q.columns = q.columns, columns
Expand Down Expand Up @@ -271,6 +325,31 @@ func (j *relationJoin) appendHasOneJoin(
}
b = append(b, ')')

if len(j.joinOn) > 0 {
b = append(b, " AND "...)

if len(j.joinOn) > 1 {
b = append(b, '(')
}

for i, on := range j.joinOn {
if i > 0 {
b = append(b, on.Sep...)
}

b = append(b, '(')
b, err = on.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
b = append(b, ')')
}

if len(j.joinOn) > 1 {
b = append(b, ')')
}
}

if isSoftDelete {
b = append(b, " AND "...)
b = j.appendAlias(fmter, b)
Expand Down