Skip to content

Commit

Permalink
[release-18.0] Planner Bug: Joins inside derived table (#14974) (#16962)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay authored Oct 17, 2024
1 parent ce4e22d commit afb135d
Show file tree
Hide file tree
Showing 22 changed files with 407 additions and 326 deletions.
13 changes: 13 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,16 @@ func TestSubqueryInAggregation(t *testing.T) {
// This fails as the planner adds `weight_string` method which make the query fail on MySQL.
// mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`)
}

// TestSubqueryInDerivedTable tests that subqueries and derived tables
// are handled correctly when there are joins inside the derived table
func TestSubqueryInDerivedTable(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 18, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("INSERT INTO t1 (id1, id2) VALUES (1, 100), (2, 200), (3, 300), (4, 400), (5, 500);")
mcmp.Exec("INSERT INTO t2 (id3, id4) VALUES (10, 1), (20, 2), (30, 3), (40, 4), (50, 99)")
mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`)
mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`)
}
30 changes: 15 additions & 15 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2821,7 +2821,7 @@ func TestCrossShardSubquery(t *testing.T) {
utils.MustMatch(t, wantQueries, sbc1.Queries)

wantQueries = []*querypb.BoundQuery{{
Sql: "select 1 from (select u2.id from `user` as u2 where u2.id = :u1_col) as t",
Sql: "select 1 from (select u2.id as id from `user` as u2 where u2.id = :u1_col) as t",
BindVariables: map[string]*querypb.BindVariable{"u1_col": sqltypes.Int32BindVariable(3)},
}}
utils.MustMatch(t, wantQueries, sbc2.Queries)
Expand Down Expand Up @@ -2896,7 +2896,7 @@ func TestCrossShardSubqueryStream(t *testing.T) {
}}
utils.MustMatch(t, wantQueries, sbc1.Queries)
wantQueries = []*querypb.BoundQuery{{
Sql: "select 1 from (select u2.id from `user` as u2 where u2.id = :u1_col) as t",
Sql: "select 1 from (select u2.id as id from `user` as u2 where u2.id = :u1_col) as t",
BindVariables: map[string]*querypb.BindVariable{"u1_col": sqltypes.Int32BindVariable(3)},
}}
utils.MustMatch(t, wantQueries, sbc2.Queries)
Expand Down Expand Up @@ -2937,7 +2937,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) {
Sql: "select t.id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where 1 != 1) as t where 1 != 1",
BindVariables: map[string]*querypb.BindVariable{},
}, {
Sql: "select 1 from (select u2.id from `user` as u2 where 1 != 1) as t where 1 != 1",
Sql: "select 1 from (select u2.id as id from `user` as u2 where 1 != 1) as t where 1 != 1",
BindVariables: map[string]*querypb.BindVariable{
"u1_col": sqltypes.NullBindVariable,
},
Expand Down Expand Up @@ -3847,14 +3847,14 @@ func TestSelectAggregationNoData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(0)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[]`,
},
Expand Down Expand Up @@ -3939,70 +3939,70 @@ func TestSelectAggregationData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`,
},
{
sql: `select count(col1) from (select id, col1 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"),
expSandboxQ: "select id, col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"),
expSandboxQ: "select col2, col1, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`,
expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`,
expRow: `[[VARCHAR("a") DECIMAL(12)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(8)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") NULL]]`,
},
Expand Down
17 changes: 13 additions & 4 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"slices"
"sort"

"vitess.io/vitess/go/slice"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
Expand Down Expand Up @@ -562,25 +564,32 @@ func buildProjection(op *Projection, qb *queryBuilder) error {
}

func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) error {
predicates := slice.Map(op.JoinPredicates, func(jc JoinColumn) sqlparser.Expr {
// since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done
qb.ctx.SkipPredicates[jc.RHSExpr] = nil

return jc.Original.Expr
})
pred := sqlparser.AndExpressions(predicates...)
err := buildQuery(op.LHS, qb)
if err != nil {
return err
}
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
qb.ctx.SkipPredicates[expr] = nil
for _, pred := range op.JoinPredicates {
qb.ctx.SkipPredicates[pred.RHSExpr] = nil
}
qbR := &queryBuilder{ctx: qb.ctx}
err = buildQuery(op.RHS, qbR)
if err != nil {
return err
}
if op.LeftJoin {
qb.joinOuterWith(qbR, op.Predicate)
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, op.Predicate)
qb.joinInnerWith(qbR, pred)
}
return nil
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ func (a *Aggregator) isDerived() bool {
return a.DT != nil
}

func (a *Aggregator) derivedName() string {
if a.DT == nil {
return ""
}

return a.DT.Alias
}

func (a *Aggregator) FindCol(ctx *plancontext.PlanningContext, in sqlparser.Expr, underRoute bool) (int, error) {
if underRoute && a.isDerived() {
// We don't want to use columns on this operator if it's a derived table under a route.
Expand Down
50 changes: 32 additions & 18 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ type (
// LeftJoin will be true in the case of an outer join
LeftJoin bool

// Before offset planning
Predicate sqlparser.Expr

// JoinColumns keeps track of what AST expression is represented in the Columns array
JoinColumns []JoinColumn

Expand Down Expand Up @@ -86,14 +83,19 @@ type (
}
)

func NewApplyJoin(lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
return &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
Predicate: predicate,
LeftJoin: leftOuterJoin,
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, predicate sqlparser.Expr, leftOuterJoin bool) (*ApplyJoin, error) {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
LeftJoin: leftOuterJoin,
}
err := aj.AddJoinPredicate(ctx, predicate)
if err != nil {
return nil, err
}

return aj, nil
}

// Clone implements the Operator interface
Expand All @@ -105,7 +107,6 @@ func (aj *ApplyJoin) Clone(inputs []ops.Operator) ops.Operator {
kopy.JoinColumns = slices.Clone(aj.JoinColumns)
kopy.JoinPredicates = slices.Clone(aj.JoinPredicates)
kopy.Vars = maps.Clone(aj.Vars)
kopy.Predicate = sqlparser.CloneExpr(aj.Predicate)
kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars)
return &kopy
}
Expand Down Expand Up @@ -149,8 +150,9 @@ func (aj *ApplyJoin) IsInner() bool {
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error {
aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate)

if expr == nil {
return nil
}
col, err := BreakExpressionInLHSandRHS(ctx, expr, TableID(aj.LHS))
if err != nil {
return err
Expand Down Expand Up @@ -312,11 +314,15 @@ func (aj *ApplyJoin) addOffset(offset int) {
}

func (aj *ApplyJoin) ShortDescription() string {
pred := sqlparser.String(aj.Predicate)
columns := slice.Map(aj.JoinColumns, func(from JoinColumn) string {
return sqlparser.String(from.Original)
})
firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", "))
fn := func(cols []JoinColumn) string {
out := slice.Map(cols, func(jc JoinColumn) string {
return jc.String()
})
return strings.Join(out, ", ")
}

firstPart := fmt.Sprintf("on %s columns: %s", fn(aj.JoinPredicates), fn(aj.JoinColumns))

if len(aj.ExtraLHSVars) == 0 {
return firstPart
}
Expand Down Expand Up @@ -419,6 +425,14 @@ func (jc JoinColumn) IsMixedLeftAndRight() bool {
return len(jc.LHSExprs) > 0 && jc.RHSExpr != nil
}

func (jc JoinColumn) String() string {
rhs := sqlparser.String(jc.RHSExpr)
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

func (bve BindVarExpr) String() string {
if bve.Name == "" {
return sqlparser.String(bve.Expr)
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func BreakExpressionInLHSandRHS(
expr sqlparser.Expr,
lhs semantics.TableSet,
) (col JoinColumn, err error) {
col.Original = aeWrap(expr)
rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
nodeExpr, ok := cursor.Node().(sqlparser.Expr)
if !ok || !fetchByOffset(nodeExpr) {
Expand Down
10 changes: 7 additions & 3 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,17 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting
return tr
}

func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin {
return NewApplyJoin(op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) (*ApplyJoin, error) {
return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
}

func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) (*Route, error) {
join, err := jm.getApplyJoin(ctx, op1, op2)
if err != nil {
return nil, err
}
return &Route{
Source: jm.getApplyJoin(ctx, op1, op2),
Source: join,
MergedWith: []*Route{op2},
Routing: r,
}, nil
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
// planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization,
func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) {
type offsettable interface {
ops.Operator
planOffsets(ctx *plancontext.PlanningContext) error
}

Expand All @@ -40,6 +41,11 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera
return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in))
case offsettable:
err = op.planOffsets(ctx)
if rewrite.DebugOperatorTree {
fmt.Println("Planned offsets for:")
fmt.Println(ops.ToTree(op))
}

}
if err != nil {
return nil, nil, err
Expand Down
Loading

0 comments on commit afb135d

Please sign in to comment.