Skip to content

Commit

Permalink
Planner Bug: Joins inside derived table (#14974)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
Signed-off-by: Florent Poinsard <[email protected]>
Signed-off-by: Andrés Taylor <[email protected]>
Co-authored-by: Florent Poinsard <[email protected]>
Co-authored-by: Florent Poinsard <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2024
1 parent fa8b5ea commit 3a98b4c
Show file tree
Hide file tree
Showing 28 changed files with 935 additions and 589 deletions.
15 changes: 13 additions & 2 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func TestSubqueryInINClause(t *testing.T) {
}

func TestSubqueryInUpdate(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate")
mcmp, closer := start(t)
defer closer()

Expand All @@ -131,7 +130,6 @@ func TestSubqueryInUpdate(t *testing.T) {
}

func TestSubqueryInReference(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate")
mcmp, closer := start(t)
defer closer()

Expand Down Expand Up @@ -177,3 +175,16 @@ func TestSubqueryInAggregation(t *testing.T) {
// This fails as the planner adds `weight_string` method which make the query fail on MySQL.
// mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`)
}

// TestSubqueryInDerivedTable tests that subqueries and derived tables
// are handled correctly when there are joins inside the derived table
func TestSubqueryInDerivedTable(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")
mcmp, closer := start(t)
defer closer()

mcmp.Exec("INSERT INTO t1 (id1, id2) VALUES (1, 100), (2, 200), (3, 300), (4, 400), (5, 500);")
mcmp.Exec("INSERT INTO t2 (id3, id4) VALUES (10, 1), (20, 2), (30, 3), (40, 4), (50, 99)")
mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`)
mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`)
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ func TestMain(m *testing.M) {

func TestAddColumn(t *testing.T) {
defer cluster.PanicHandler(t)
utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate")
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
Expand Down
1 change: 1 addition & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ func (node TableName) IsEmpty() bool {
// If Name is empty, Qualifier is also empty.
return node.Name.IsEmpty()
}
func (node TableName) NonEmpty() bool { return !node.Name.IsEmpty() }

// NewWhere creates a WHERE or HAVING clause out
// of a Expr. If the expression is nil, it returns nil.
Expand Down
24 changes: 12 additions & 12 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3914,14 +3914,14 @@ func TestSelectAggregationNoData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(0)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[]`,
},
Expand Down Expand Up @@ -4006,70 +4006,70 @@ func TestSelectAggregationData(t *testing.T) {
{
sql: `select count(*) from (select col1, col2 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"),
expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"count(*)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"),
expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`,
},
{
sql: `select count(col1) from (select id, col1 from user limit 2) x`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"),
expSandboxQ: "select id, col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64]`,
expRow: `[[INT64(2)]]`,
},
{
sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"),
expSandboxQ: "select col2, col1, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit :__upper_limit",
expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`,
expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`,
expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`,
expRow: `[[VARCHAR("a") DECIMAL(12)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(8)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") FLOAT64(0)]]`,
},
{
sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"),
expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit",
expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`,
expRow: `[[VARCHAR("a") NULL]]`,
},
Expand Down
25 changes: 15 additions & 10 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"slices"
"sort"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -84,7 +85,7 @@ func (qb *queryBuilder) addTableExpr(
}

func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {
if _, toBeSkipped := qb.ctx.SkipPredicates[expr]; toBeSkipped {
if qb.ctx.ShouldSkip(expr) {
// This is a predicate that was added to the RHS of an ApplyJoin.
// The original predicate will be added, so we don't have to add this here
return
Expand Down Expand Up @@ -523,20 +524,24 @@ func buildProjection(op *Projection, qb *queryBuilder) {
}

func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) {
predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr {
// since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done
err := qb.ctx.SkipJoinPredicates(jc.Original)
if err != nil {
panic(err)
}
return jc.Original
})
pred := sqlparser.AndExpressions(predicates...)

buildQuery(op.LHS, qb)
// If we are going to add the predicate used in join here
// We should not add the predicate's copy of when it was split into
// two parts. To avoid this, we use the SkipPredicates map.
for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
qb.ctx.SkipPredicates[expr] = nil
}

qbR := &queryBuilder{ctx: qb.ctx}
buildQuery(op.RHS, qbR)

if op.LeftJoin {
qb.joinOuterWith(qbR, op.Predicate)
qb.joinOuterWith(qbR, pred)
} else {
qb.joinInnerWith(qbR, op.Predicate)
qb.joinInnerWith(qbR, pred)
}
}

Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ func (a *Aggregator) isDerived() bool {
return a.DT != nil
}

func (a *Aggregator) derivedName() string {
if a.DT == nil {
return ""
}

return a.DT.Alias
}

func (a *Aggregator) FindCol(ctx *plancontext.PlanningContext, in sqlparser.Expr, underRoute bool) int {
if underRoute && a.isDerived() {
// We don't want to use columns on this operator if it's a derived table under a route.
Expand Down
37 changes: 23 additions & 14 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ type (
// LeftJoin will be true in the case of an outer join
LeftJoin bool

// Before offset planning
Predicate sqlparser.Expr

// JoinColumns keeps track of what AST expression is represented in the Columns array
JoinColumns *applyJoinColumns

Expand Down Expand Up @@ -85,16 +82,17 @@ type (
}
)

func NewApplyJoin(lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
return &ApplyJoin{
func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
aj := &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
Predicate: predicate,
LeftJoin: leftOuterJoin,
JoinColumns: &applyJoinColumns{},
JoinPredicates: &applyJoinColumns{},
}
aj.AddJoinPredicate(ctx, predicate)
return aj
}

// Clone implements the Operator interface
Expand All @@ -106,7 +104,6 @@ func (aj *ApplyJoin) Clone(inputs []Operator) Operator {
kopy.JoinColumns = aj.JoinColumns.clone()
kopy.JoinPredicates = aj.JoinPredicates.clone()
kopy.Vars = maps.Clone(aj.Vars)
kopy.Predicate = sqlparser.CloneExpr(aj.Predicate)
kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars)
return &kopy
}
Expand Down Expand Up @@ -150,8 +147,9 @@ func (aj *ApplyJoin) IsInner() bool {
}

func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) {
aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate)

if expr == nil {
return
}
col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, TableID(aj.LHS))
aj.JoinPredicates.add(col)
rhs := aj.RHS.AddPredicate(ctx, col.RHSExpr)
Expand Down Expand Up @@ -266,11 +264,14 @@ func (aj *ApplyJoin) addOffset(offset int) {
}

func (aj *ApplyJoin) ShortDescription() string {
pred := sqlparser.String(aj.Predicate)
columns := slice.Map(aj.JoinColumns.columns, func(from applyJoinColumn) string {
return sqlparser.String(from.Original)
})
firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", "))
fn := func(cols *applyJoinColumns) string {
out := slice.Map(cols.columns, func(jc applyJoinColumn) string {
return jc.String()
})
return strings.Join(out, ", ")
}

firstPart := fmt.Sprintf("on %s columns: %s", fn(aj.JoinPredicates), fn(aj.JoinColumns))
if len(aj.ExtraLHSVars) == 0 {
return firstPart
}
Expand Down Expand Up @@ -361,6 +362,14 @@ func (a *ApplyJoin) LHSColumnsNeeded(ctx *plancontext.PlanningContext) (needed s
return ctx.SemTable.Uniquify(needed)
}

func (jc applyJoinColumn) String() string {
rhs := sqlparser.String(jc.RHSExpr)
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

func (jc applyJoinColumn) IsPureLeft() bool {
return jc.RHSExpr == nil
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ func createOpFromStmt(inCtx *plancontext.PlanningContext, stmt sqlparser.Stateme
if err != nil {
panic(err)
}

// need to remember which predicates have been broken up during join planning
inCtx.KeepPredicateInfo(ctx)

return op
}

Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func breakExpressionInLHSandRHSForApplyJoin(
cursor.Replace(arg)
}, nil).(sqlparser.Expr)

ctx.JoinPredicates[expr] = append(ctx.JoinPredicates[expr], rewrittenExpr)
ctx.AddJoinPredicates(expr, rewrittenExpr)
col.RHSExpr = rewrittenExpr
col.Original = expr
return
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting
}

func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin {
return NewApplyJoin(op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin)
}

func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) *Route {
Expand Down
12 changes: 10 additions & 2 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
// planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization,
func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator {
type offsettable interface {
Operator
planOffsets(ctx *plancontext.PlanningContext) Operator
}

Expand All @@ -37,9 +38,16 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator {
panic(vterrors.VT13001(fmt.Sprintf("should not see %T here", in)))
case offsettable:
newOp := op.planOffsets(ctx)
if newOp != nil {
return newOp, Rewrote("new operator after offset planning")

if newOp == nil {
newOp = op
}

if DebugOperatorTree {
fmt.Println("Planned offsets for:")
fmt.Println(ToTree(newOp))
}
return newOp, nil
}
return in, NoRewrite
}
Expand Down
10 changes: 9 additions & 1 deletion go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (sp StarProjections) GetSelectExprs() sqlparser.SelectExprs {

func (ap AliasedProjections) GetColumns() []*sqlparser.AliasedExpr {
return slice.Map(ap, func(from *ProjExpr) *sqlparser.AliasedExpr {
return aeWrap(from.ColExpr)
return from.Original
})
}

Expand Down Expand Up @@ -229,6 +229,14 @@ func (p *Projection) isDerived() bool {
return p.DT != nil
}

func (p *Projection) derivedName() string {
if p.DT == nil {
return ""
}

return p.DT.Alias
}

func (p *Projection) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int {
ap, err := p.GetAliasedProjections()
if err != nil {
Expand Down
Loading

0 comments on commit 3a98b4c

Please sign in to comment.