Skip to content

Commit

Permalink
Cleanup relation code
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Aug 30, 2021
1 parent 832d5c1 commit 73ea38a
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 75 deletions.
48 changes: 22 additions & 26 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import (
"github.com/uptrace/bun/schema"
)

type join struct {
Parent *join
type relationJoin struct {
Parent *relationJoin
BaseModel tableModel
JoinModel tableModel
Relation *schema.Relation

ApplyQueryFunc func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
apply func(*SelectQuery) *SelectQuery
columns []schema.QueryWithArgs
}

func (j *join) applyQuery(q *SelectQuery) {
if j.ApplyQueryFunc == nil {
func (j *relationJoin) applyTo(q *SelectQuery) {
if j.apply == nil {
return
}

Expand All @@ -30,32 +30,28 @@ func (j *join) applyQuery(q *SelectQuery) {
table, q.table = q.table, j.JoinModel.Table()
columns, q.columns = q.columns, nil

q = j.ApplyQueryFunc(q)
q = j.apply(q)

// Restore state.
q.table = table
j.columns, q.columns = q.columns, columns
}

func (j *join) Select(ctx context.Context, q *SelectQuery) error {
func (j *relationJoin) Select(ctx context.Context, q *SelectQuery) error {
switch j.Relation.Type {
case schema.HasManyRelation:
return j.selectMany(ctx, q)
case schema.ManyToManyRelation:
return j.selectM2M(ctx, q)
}
panic("not reached")
}

func (j *join) selectMany(ctx context.Context, q *SelectQuery) error {
func (j *relationJoin) selectMany(ctx context.Context, q *SelectQuery) error {
q = j.manyQuery(q)
if q == nil {
return nil
}
return q.Scan(ctx)
}

func (j *join) manyQuery(q *SelectQuery) *SelectQuery {
func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
hasManyModel := newHasManyModel(j)
if hasManyModel == nil {
return nil
Expand Down Expand Up @@ -86,13 +82,13 @@ func (j *join) manyQuery(q *SelectQuery) *SelectQuery {
q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue)
}

j.applyQuery(q)
j.applyTo(q)
q = q.Apply(j.hasManyColumns)

return q
}

func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery {
func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery {
if j.Relation.M2MTable != nil {
q = q.ColumnExpr(string(j.Relation.M2MTable.SQLAlias) + ".*")
}
Expand Down Expand Up @@ -122,15 +118,15 @@ func (j *join) hasManyColumns(q *SelectQuery) *SelectQuery {
return q
}

func (j *join) selectM2M(ctx context.Context, q *SelectQuery) error {
func (j *relationJoin) selectM2M(ctx context.Context, q *SelectQuery) error {
q = j.m2mQuery(q)
if q == nil {
return nil
}
return q.Scan(ctx)
}

func (j *join) m2mQuery(q *SelectQuery) *SelectQuery {
func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
fmter := q.db.fmter

m2mModel := newM2MModel(j)
Expand Down Expand Up @@ -170,13 +166,13 @@ func (j *join) m2mQuery(q *SelectQuery) *SelectQuery {
j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName)
}

j.applyQuery(q)
j.applyTo(q)
q = q.Apply(j.hasManyColumns)

return q
}

func (j *join) hasParent() bool {
func (j *relationJoin) hasParent() bool {
if j.Parent != nil {
switch j.Parent.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
Expand All @@ -186,7 +182,7 @@ func (j *join) hasParent() bool {
return false
}

func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte {
func (j *relationJoin) appendAlias(fmter schema.Formatter, b []byte) []byte {
quote := fmter.IdentQuote()

b = append(b, quote)
Expand All @@ -195,7 +191,7 @@ func (j *join) appendAlias(fmter schema.Formatter, b []byte) []byte {
return b
}

func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte {
func (j *relationJoin) appendAliasColumn(fmter schema.Formatter, b []byte, column string) []byte {
quote := fmter.IdentQuote()

b = append(b, quote)
Expand All @@ -206,7 +202,7 @@ func (j *join) appendAliasColumn(fmter schema.Formatter, b []byte, column string
return b
}

func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
func (j *relationJoin) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
quote := fmter.IdentQuote()

if j.hasParent() {
Expand All @@ -218,7 +214,7 @@ func (j *join) appendBaseAlias(fmter schema.Formatter, b []byte) []byte {
return append(b, j.BaseModel.Table().SQLAlias...)
}

func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte {
func (j *relationJoin) appendSoftDelete(b []byte, flags internal.Flag) []byte {
b = append(b, '.')
b = append(b, j.JoinModel.Table().SoftDeleteField.SQLName...)
if flags.Has(deletedFlag) {
Expand All @@ -229,7 +225,7 @@ func (j *join) appendSoftDelete(b []byte, flags internal.Flag) []byte {
return b
}

func appendAlias(b []byte, j *join) []byte {
func appendAlias(b []byte, j *relationJoin) []byte {
if j.hasParent() {
b = appendAlias(b, j.Parent)
b = append(b, "__"...)
Expand All @@ -238,7 +234,7 @@ func appendAlias(b []byte, j *join) []byte {
return b
}

func (j *join) appendHasOneJoin(
func (j *relationJoin) appendHasOneJoin(
fmter schema.Formatter, b []byte, q *SelectQuery,
) (_ []byte, err error) {
isSoftDelete := j.JoinModel.Table().SoftDeleteField != nil && !q.flags.Has(allWithDeletedFlag)
Expand Down
8 changes: 4 additions & 4 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ type tableModel interface {
Table() *schema.Table
Relation() *schema.Relation

Join(string, func(*SelectQuery) *SelectQuery) *join
GetJoin(string) *join
GetJoins() []join
AddJoin(join) *join
Join(string) *relationJoin
GetJoin(string) *relationJoin
GetJoins() []relationJoin
AddJoin(relationJoin) *relationJoin

Root() reflect.Value
ParentIndex() []int
Expand Down
2 changes: 1 addition & 1 deletion model_table_has_many.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type hasManyModel struct {

var _ tableModel = (*hasManyModel)(nil)

func newHasManyModel(j *join) *hasManyModel {
func newHasManyModel(j *relationJoin) *hasManyModel {
baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, j.Relation.BaseFields)
Expand Down
2 changes: 1 addition & 1 deletion model_table_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type m2mModel struct {

var _ tableModel = (*m2mModel)(nil)

func newM2MModel(j *join) *m2mModel {
func newM2MModel(j *relationJoin) *m2mModel {
baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, baseTable.PKs)
Expand Down
4 changes: 2 additions & 2 deletions model_table_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func (m *sliceTableModel) init(sliceType reflect.Type) {
}
}

func (m *sliceTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
return m.join(m.slice, name, apply)
func (m *sliceTableModel) Join(name string) *relationJoin {
return m.join(m.slice, name)
}

func (m *sliceTableModel) Bind(bind reflect.Value) {
Expand Down
28 changes: 9 additions & 19 deletions model_table_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type structTableModel struct {
table *schema.Table

rel *schema.Relation
joins []join
joins []relationJoin

dest interface{}
root reflect.Value
Expand Down Expand Up @@ -151,7 +151,7 @@ func (m *structTableModel) AfterScan(ctx context.Context) error {
return firstErr
}

func (m *structTableModel) GetJoin(name string) *join {
func (m *structTableModel) GetJoin(name string) *relationJoin {
for i := range m.joins {
j := &m.joins[i]
if j.Relation.Field.Name == name || j.Relation.Field.GoName == name {
Expand All @@ -161,30 +161,28 @@ func (m *structTableModel) GetJoin(name string) *join {
return nil
}

func (m *structTableModel) GetJoins() []join {
func (m *structTableModel) GetJoins() []relationJoin {
return m.joins
}

func (m *structTableModel) AddJoin(j join) *join {
func (m *structTableModel) AddJoin(j relationJoin) *relationJoin {
m.joins = append(m.joins, j)
return &m.joins[len(m.joins)-1]
}

func (m *structTableModel) Join(name string, apply func(*SelectQuery) *SelectQuery) *join {
return m.join(m.strct, name, apply)
func (m *structTableModel) Join(name string) *relationJoin {
return m.join(m.strct, name)
}

func (m *structTableModel) join(
bind reflect.Value, name string, apply func(*SelectQuery) *SelectQuery,
) *join {
func (m *structTableModel) join(bind reflect.Value, name string) *relationJoin {
path := strings.Split(name, ".")
index := make([]int, 0, len(path))

currJoin := join{
currJoin := relationJoin{
BaseModel: m,
JoinModel: m,
}
var lastJoin *join
var lastJoin *relationJoin

for _, name := range path {
relation, ok := currJoin.JoinModel.Table().Relations[name]
Expand Down Expand Up @@ -214,14 +212,6 @@ func (m *structTableModel) join(
}
}

// No joins with such name.
if lastJoin == nil {
return nil
}
if apply != nil {
lastJoin.ApplyQueryFunc = apply
}

return lastJoin
}

Expand Down
48 changes: 26 additions & 22 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,41 +286,38 @@ func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *Selec

//------------------------------------------------------------------------------

// Relation adds a relation to the query. Relation name can be:
// - RelationName to select all columns,
// - RelationName.column_name,
// - RelationName._ to join relation without selecting relation columns.
// Relation adds a relation to the query.
func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery {
if len(apply) > 1 {
panic("only one apply function is supported")
}

if q.tableModel == nil {
q.setErr(errNilModel)
return q
}

var fn func(*SelectQuery) *SelectQuery

if len(apply) == 1 {
fn = apply[0]
} else if len(apply) > 1 {
panic("only one apply function is supported")
}

join := q.tableModel.Join(name, fn)
join := q.tableModel.Join(name)
if join == nil {
q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name))
return q
}

if len(apply) == 1 {
join.apply = apply[0]
}

return q
}

func (q *SelectQuery) forEachHasOneJoin(fn func(*join) error) error {
func (q *SelectQuery) forEachHasOneJoin(fn func(*relationJoin) error) error {
if q.tableModel == nil {
return nil
}
return q._forEachHasOneJoin(fn, q.tableModel.GetJoins())
}

func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) error {
func (q *SelectQuery) _forEachHasOneJoin(fn func(*relationJoin) error, joins []relationJoin) error {
for i := range joins {
j := &joins[i]
switch j.Relation.Type {
Expand All @@ -336,16 +333,23 @@ func (q *SelectQuery) _forEachHasOneJoin(fn func(*join) error, joins []join) err
return nil
}

func (q *SelectQuery) selectJoins(ctx context.Context, joins []join) error {
var err error
func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) error {
for i := range joins {
j := &joins[i]

var err error

switch j.Relation.Type {
case schema.HasOneRelation, schema.BelongsToRelation:
err = q.selectJoins(ctx, j.JoinModel.GetJoins())
case schema.HasManyRelation:
err = j.selectMany(ctx, q.db.NewSelect())
case schema.ManyToManyRelation:
err = j.selectM2M(ctx, q.db.NewSelect())
default:
err = j.Select(ctx, q.db.NewSelect())
panic("not reached")
}

if err != nil {
return err
}
Expand Down Expand Up @@ -415,7 +419,7 @@ func (q *SelectQuery) appendQuery(
}
}

if err := q.forEachHasOneJoin(func(j *join) error {
if err := q.forEachHasOneJoin(func(j *relationJoin) error {
b = append(b, ' ')
b, err = j.appendHasOneJoin(fmter, b, q)
return err
Expand Down Expand Up @@ -545,7 +549,7 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
b = append(b, '*')
}

if err := q.forEachHasOneJoin(func(j *join) error {
if err := q.forEachHasOneJoin(func(j *relationJoin) error {
if len(b) != start {
b = append(b, ", "...)
start = len(b)
Expand All @@ -567,9 +571,9 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
}

func (q *SelectQuery) appendHasOneColumns(
fmter schema.Formatter, b []byte, join *join,
fmter schema.Formatter, b []byte, join *relationJoin,
) (_ []byte, err error) {
join.applyQuery(q)
join.applyTo(q)

if join.columns != nil {
for i, col := range join.columns {
Expand Down

0 comments on commit 73ea38a

Please sign in to comment.