From 782cb9bbba48cd2d97d42a36ea4deec93dc0fe4e Mon Sep 17 00:00:00 2001 From: Denis Gukov Date: Mon, 20 Jun 2022 16:25:23 +0500 Subject: [PATCH 1/3] feat(relations): support macro ?TableAlias in relation callback --- internal/dbtest/relation_join_test.go | 131 ++++++++++++++++++++++++++ relation_join.go | 29 +++++- 2 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 internal/dbtest/relation_join_test.go diff --git a/internal/dbtest/relation_join_test.go b/internal/dbtest/relation_join_test.go new file mode 100644 index 000000000..df67ba82f --- /dev/null +++ b/internal/dbtest/relation_join_test.go @@ -0,0 +1,131 @@ +package dbtest_test + +import ( + "context" + "database/sql" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/sqlitedialect" + "github.com/uptrace/bun/driver/sqliteshim" + "github.com/uptrace/bun/extra/bundebug" + "testing" +) + +type TestRelProfile struct { + ID int64 `bun:",pk,autoincrement"` + Lang string + UserID int64 +} + +type TestRelUser struct { + ID int64 `bun:",pk,autoincrement"` + Name string + Profile *TestRelProfile `bun:"rel:has-one,join:id=user_id"` + Disks []TestRelDisk `bun:"rel:has-many,join:id=user_id"` +} + +type TestRelDisk struct { + ID int64 `bun:",pk,autoincrement"` + Title string + UserID int64 + User *TestRelUser `bun:"rel:belongs-to,join:user_id=id"` +} + +func TestRelationJoin(t *testing.T) { + + ctx := context.Background() + + sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared") + if err != nil { + panic(err) + } + + db := bun.NewDB(sqldb, sqlitedialect.New()) + defer db.Close() + + db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) + + // Create schema + + models := []interface{}{ + (*TestRelUser)(nil), + (*TestRelProfile)(nil), + (*TestRelDisk)(nil), + } + for _, model := range models { + _, err = db.NewCreateTable().Model(model).Exec(ctx) + require.NoError(t, err) + } + + expectedUsers := []*TestRelUser{ + {ID: 1, Name: "user 1"}, + {ID: 2, Name: "user 2"}, + } + + _, err = db.NewInsert().Model(&expectedUsers).Exec(ctx) + require.NoError(t, err) + + expectedProfiles := []*TestRelProfile{ + {ID: 1, Lang: "en", UserID: 1}, + {ID: 2, Lang: "ru", UserID: 2}, + } + + _, err = db.NewInsert().Model(&expectedProfiles).Exec(ctx) + require.NoError(t, err) + + expectedDisks := []*TestRelDisk{ + {ID: 1, Title: "Nirvana", UserID: 1}, + {ID: 2, Title: "Linkin Park", UserID: 2}, + } + + _, err = db.NewInsert().Model(&expectedDisks).Exec(ctx) + require.NoError(t, err) + + // test Has One relation + + var users []TestRelUser + err = db.NewSelect(). + Model(&users). + Relation("Profile"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, len(expectedUsers), len(users)) + + users = []TestRelUser{} + err = db.NewSelect(). + Model(&users). + Relation("Profile", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("?TableAlias.lang = ?", "ru") + }). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(users)) + + // test Has Many relation + + users = []TestRelUser{} + err = db.NewSelect(). + Model(&users). + Relation("Disks", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("?TableAlias.title = ?", "Linkin Park") + }). + Order("id"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 0, len(users[0].Disks)) + require.Equal(t, 1, len(users[1].Disks)) + require.Equal(t, "Linkin Park", users[1].Disks[0].Title) + + // test Belongs To relation + + var disks []TestRelDisk + err = db.NewSelect(). + Model(&disks). + Relation("User", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("?TableAlias.name = ?", "user 2") + }). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 1, len(disks)) + require.Equal(t, "Linkin Park", disks[0].Title) +} diff --git a/relation_join.go b/relation_join.go index e8074e0c6..cf58482a9 100644 --- a/relation_join.go +++ b/relation_join.go @@ -2,10 +2,10 @@ package bun import ( "context" - "reflect" - "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" + "reflect" + "regexp" ) type relationJoin struct { @@ -30,8 +30,33 @@ func (j *relationJoin) applyTo(q *SelectQuery) { table, q.table = q.table, j.JoinModel.Table() columns, q.columns = q.columns, nil + oldWhere := q.where + q = j.apply(q) + var newWhere []schema.QueryWithSep + + var alias string + switch j.Relation.Type { + case schema.HasOneRelation, schema.BelongsToRelation: + alias = string(j.appendAlias(q.db.fmter, []byte{})) + case schema.HasManyRelation: + alias = j.Relation.JoinTable.Alias + case schema.ManyToManyRelation: + alias = j.Relation.JoinTable.Alias + } + + var re = regexp.MustCompile(`\?TableAlias\b`) + + for i, w := range q.where { + if i >= len(oldWhere) { + w.Query = re.ReplaceAllString(w.Query, " "+alias) + } + newWhere = append(newWhere, w) + } + + q.where = newWhere + // Restore state. q.table = table j.columns, q.columns = q.columns, columns From 4b04606185cf9d91c1bae6ec228b33445f28a7d6 Mon Sep 17 00:00:00 2001 From: Denis Gukov Date: Mon, 27 Jun 2022 14:26:44 +0500 Subject: [PATCH 2/3] fix(relations): use formatter instaed of regexp --- relation_join.go | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/relation_join.go b/relation_join.go index cf58482a9..f8a46eb9b 100644 --- a/relation_join.go +++ b/relation_join.go @@ -5,7 +5,6 @@ import ( "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" "reflect" - "regexp" ) type relationJoin struct { @@ -18,6 +17,23 @@ type relationJoin struct { columns []schema.QueryWithArgs } +type tableAliasArg struct { + j *relationJoin + alias string +} + +func (a *tableAliasArg) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { + if name != "TableAlias" { + return nil, false + } + + if a.alias == "" { + return a.j.appendAlias(fmter, b), true + } + + return fmter.AppendIdent(b, a.alias), true +} + func (j *relationJoin) applyTo(q *SelectQuery) { if j.apply == nil { return @@ -37,20 +53,16 @@ func (j *relationJoin) applyTo(q *SelectQuery) { var newWhere []schema.QueryWithSep var alias string - switch j.Relation.Type { - case schema.HasOneRelation, schema.BelongsToRelation: - alias = string(j.appendAlias(q.db.fmter, []byte{})) - case schema.HasManyRelation: - alias = j.Relation.JoinTable.Alias - case schema.ManyToManyRelation: + + if j.Relation.Type == schema.HasManyRelation || j.Relation.Type == schema.ManyToManyRelation { alias = j.Relation.JoinTable.Alias } - var re = regexp.MustCompile(`\?TableAlias\b`) + fmter := q.db.fmter.WithArg(&tableAliasArg{j: j, alias: alias}) for i, w := range q.where { if i >= len(oldWhere) { - w.Query = re.ReplaceAllString(w.Query, " "+alias) + w.Query = string(fmter.AppendQuery([]byte{}, w.Query)) } newWhere = append(newWhere, w) } From 056a7935340065fc70686fdf8582bc38ce15752b Mon Sep 17 00:00:00 2001 From: Denis Gukov Date: Mon, 27 Jun 2022 21:35:33 +0500 Subject: [PATCH 3/3] feat(repations): support JoinOn in relation callback --- internal/dbtest/relation_join_test.go | 18 ++++++++++++ relation_join.go | 42 +++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/internal/dbtest/relation_join_test.go b/internal/dbtest/relation_join_test.go index df67ba82f..51668e41c 100644 --- a/internal/dbtest/relation_join_test.go +++ b/internal/dbtest/relation_join_test.go @@ -91,6 +91,8 @@ func TestRelationJoin(t *testing.T) { require.NoError(t, err) require.Equal(t, len(expectedUsers), len(users)) + // test Has One relation with filter + users = []TestRelUser{} err = db.NewSelect(). Model(&users). @@ -101,6 +103,22 @@ func TestRelationJoin(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, len(users)) + // test Has One relation with join on + + users = []TestRelUser{} + err = db.NewSelect(). + Model(&users). + Relation("Profile", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.JoinOn("?TableAlias.lang = ?", "ru") + }). + OrderExpr("?TableAlias.ID"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, 2, len(users)) + require.Nil(t, users[0].Profile) + require.NotNil(t, users[1].Profile) + require.Equal(t, int64(2), users[1].Profile.ID) + // test Has Many relation users = []TestRelUser{} diff --git a/relation_join.go b/relation_join.go index f8a46eb9b..30b2c1c9f 100644 --- a/relation_join.go +++ b/relation_join.go @@ -15,6 +15,7 @@ type relationJoin struct { apply func(*SelectQuery) *SelectQuery columns []schema.QueryWithArgs + joinOn []schema.QueryWithSep } type tableAliasArg struct { @@ -41,6 +42,7 @@ func (j *relationJoin) applyTo(q *SelectQuery) { var table *schema.Table var columns []schema.QueryWithArgs + var joins []joinQuery // Save state. table, q.table = q.table, j.JoinModel.Table() @@ -48,6 +50,10 @@ func (j *relationJoin) applyTo(q *SelectQuery) { oldWhere := q.where + if j.Relation.Type == schema.HasOneRelation || j.Relation.Type == schema.BelongsToRelation { + joins, q.joins = q.joins, []joinQuery{{}} + } + q = j.apply(q) var newWhere []schema.QueryWithSep @@ -69,6 +75,17 @@ func (j *relationJoin) applyTo(q *SelectQuery) { q.where = newWhere + if j.Relation.Type == schema.HasOneRelation || j.Relation.Type == schema.BelongsToRelation { + var joinOn []schema.QueryWithSep + + for _, on := range q.joins[0].on { + on.Query = string(fmter.AppendQuery([]byte{}, on.Query)) + joinOn = append(joinOn, on) + } + + j.joinOn, q.joins = joinOn, joins + } + // Restore state. q.table = table j.columns, q.columns = q.columns, columns @@ -308,6 +325,31 @@ func (j *relationJoin) appendHasOneJoin( } b = append(b, ')') + if len(j.joinOn) > 0 { + b = append(b, " AND "...) + + if len(j.joinOn) > 1 { + b = append(b, '(') + } + + for i, on := range j.joinOn { + if i > 0 { + b = append(b, on.Sep...) + } + + b = append(b, '(') + b, err = on.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + b = append(b, ')') + } + + if len(j.joinOn) > 1 { + b = append(b, ')') + } + } + if isSoftDelete { b = append(b, " AND "...) b = j.appendAlias(fmter, b)