diff --git a/go/test/endtoend/vtgate/gen4/gen4_test.go b/go/test/endtoend/vtgate/gen4/gen4_test.go index 8764328495c..f284f85e883 100644 --- a/go/test/endtoend/vtgate/gen4/gen4_test.go +++ b/go/test/endtoend/vtgate/gen4/gen4_test.go @@ -187,26 +187,6 @@ func TestSubQueriesOnOuterJoinOnCondition(t *testing.T) { } } -func TestPlannerWarning(t *testing.T) { - mcmp, closer := start(t) - defer closer() - - // straight_join query - _ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) - - // execute same query again. - _ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) - - // random query to reset the warning. - _ = utils.Exec(t, mcmp.VtConn, `select 1 from t1`) - - // execute same query again. - _ = utils.Exec(t, mcmp.VtConn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - utils.AssertMatches(t, mcmp.VtConn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) -} - func TestHashJoin(t *testing.T) { mcmp, closer := start(t) defer closer() diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index f6d6a18aa3b..c10cb4c9b71 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -425,3 +425,30 @@ func TestAlterTableWithView(t *testing.T) { mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`) } + +// TestStraightJoin tests that Vitess respects the ordering of join in a STRAIGHT JOIN query. +func TestStraightJoin(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (1,0,10), (2,10,10), (3,4,20), (4,30,20), (5,40,10)") + mcmp.Exec(`insert into t1(id1, id2) values (10, 11), (20, 13)`) + + mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", + `[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`, + ) + // Verify that in a normal join query, vitess joins tbl with t1. + res, err := mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 join tbl where t1.id1 = tbl.nonunq_col", 100, false) + require.NoError(t, err) + require.Contains(t, fmt.Sprintf("%v", res.Rows), "tbl_t1") + + // Test the same query with a straight join + mcmp.AssertMatchesNoOrder("select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", + `[[INT64(0) INT64(10) INT64(11)] [INT64(10) INT64(10) INT64(11)] [INT64(4) INT64(20) INT64(13)] [INT64(40) INT64(10) INT64(11)] [INT64(30) INT64(20) INT64(13)]]`, + ) + // Verify that in a straight join query, vitess joins t1 with tbl. + res, err = mcmp.VtConn.ExecuteFetch("vexplain plan select tbl.unq_col, tbl.nonunq_col, t1.id2 from t1 straight_join tbl where t1.id1 = tbl.nonunq_col", 100, false) + require.NoError(t, err) + require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl") +} diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 4335e2432f9..77cbed714b0 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1881,6 +1881,26 @@ func (node DatabaseOptionType) ToString() string { } } +// IsCommutative returns whether the join type supports rearranging or not. +func (joinType JoinType) IsCommutative() bool { + switch joinType { + case StraightJoinType, LeftJoinType, RightJoinType, NaturalLeftJoinType, NaturalRightJoinType: + return false + default: + return true + } +} + +// IsInner returns whether the join type is an inner join or not. +func (joinType JoinType) IsInner() bool { + switch joinType { + case StraightJoinType, NaturalJoinType, NormalJoinType: + return true + default: + return false + } +} + // ToString returns the type as a string func (ty LockType) ToString() string { switch ty { diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 8f3e436deb8..6f2ad00f514 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -3343,18 +3343,12 @@ func TestGen4SelectStraightJoin(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{ { - Sql: "select u.id from `user` as u, user2 as u2 where u.id = u2.id", + Sql: "select u.id from `user` as u straight_join user2 as u2 on u.id = u2.id", BindVariables: map[string]*querypb.BindVariable{}, }, } - wantWarnings := []*querypb.QueryWarning{ - { - Code: 1235, - Message: "straight join is converted to normal join", - }, - } utils.MustMatch(t, wantQueries, sbc1.Queries) - utils.MustMatch(t, wantWarnings, session.Warnings) + require.Empty(t, session.Warnings) } func TestGen4MultiColumnVindexEqual(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 6d2c2317517..f0783a5ecfb 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -464,7 +464,7 @@ func transformApplyJoinPlan(ctx *plancontext.PlanningContext, n *operators.Apply return nil, err } opCode := engine.InnerJoin - if n.LeftJoin { + if !n.JoinType.IsInner() { opCode = engine.LeftJoin } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 305ab299abe..5fc8f36c646 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -213,7 +213,7 @@ var _ FromStatement = (*sqlparser.Select)(nil) var _ FromStatement = (*sqlparser.Update)(nil) var _ FromStatement = (*sqlparser.Delete)(nil) -func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) { +func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) { stmt := qb.stmt.(FromStatement) otherStmt := other.stmt.(FromStatement) @@ -222,24 +222,18 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) } - newFromClause := append(stmt.GetFrom(), otherStmt.GetFrom()...) - stmt.SetFrom(newFromClause) qb.mergeWhereClauses(stmt, otherStmt) - qb.addPredicate(onCondition) -} - -func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) { - stmt := qb.stmt.(FromStatement) - otherStmt := other.stmt.(FromStatement) - if sel, isSel := stmt.(*sqlparser.Select); isSel { - otherSel := otherStmt.(*sqlparser.Select) - sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) + var newFromClause []sqlparser.TableExpr + switch joinType { + case sqlparser.NormalJoinType: + newFromClause = append(stmt.GetFrom(), otherStmt.GetFrom()...) + qb.addPredicate(onCondition) + default: + newFromClause = []sqlparser.TableExpr{buildJoin(stmt, otherStmt, onCondition, joinType)} } - newFromClause := []sqlparser.TableExpr{buildOuterJoin(stmt, otherStmt, onCondition)} stmt.SetFrom(newFromClause) - qb.mergeWhereClauses(stmt, otherStmt) } func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { @@ -254,7 +248,7 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { } } -func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr) *sqlparser.JoinTableExpr { +func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { var lhs sqlparser.TableExpr fromClause := stmt.GetFrom() if len(fromClause) == 1 { @@ -273,7 +267,7 @@ func buildOuterJoin(stmt FromStatement, otherStmt FromStatement, onCondition sql return &sqlparser.JoinTableExpr{ LeftExpr: lhs, RightExpr: rhs, - Join: sqlparser.LeftJoinType, + Join: joinType, Condition: &sqlparser.JoinCondition{ On: onCondition, }, @@ -539,11 +533,7 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.RHS, qbR) - if op.LeftJoin { - qb.joinOuterWith(qbR, pred) - } else { - qb.joinInnerWith(qbR, pred) - } + qb.joinWith(qbR, pred, op.JoinType) } func buildUnion(op *Union, qb *queryBuilder) { diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index f4aa851e176..9d5b76b09a0 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -373,7 +373,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr rhs := createJoinPusher(rootAggr, join.RHS) columns := &applyJoinColumns{} - output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, join.LeftJoin, columns, lhs, rhs) + output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, !join.JoinType.IsInner(), columns, lhs, rhs) join.JoinColumns = columns if err != nil { // if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index c182bb2fb83..2e72f2eae57 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -34,6 +34,9 @@ type ( ApplyJoin struct { LHS, RHS Operator + // JoinType is permitted to store only 3 of the possible values + // NormalJoinType, StraightJoinType and LeftJoinType. + JoinType sqlparser.JoinType // LeftJoin will be true in the case of an outer join LeftJoin bool @@ -82,12 +85,12 @@ type ( } ) -func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin { +func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, joinType sqlparser.JoinType) *ApplyJoin { aj := &ApplyJoin{ LHS: lhs, RHS: rhs, Vars: map[string]int{}, - LeftJoin: leftOuterJoin, + JoinType: joinType, JoinColumns: &applyJoinColumns{}, JoinPredicates: &applyJoinColumns{}, } @@ -139,11 +142,14 @@ func (aj *ApplyJoin) SetRHS(operator Operator) { } func (aj *ApplyJoin) MakeInner() { - aj.LeftJoin = false + if aj.IsInner() { + return + } + aj.JoinType = sqlparser.NormalJoinType } func (aj *ApplyJoin) IsInner() bool { - return !aj.LeftJoin + return aj.JoinType.IsInner() } func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 14d4de1bf51..5633239346d 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -229,7 +229,9 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s case sqlparser.NormalJoinType: return createInnerJoin(ctx, tableExpr, lhs, rhs) case sqlparser.LeftJoinType, sqlparser.RightJoinType: - return createOuterJoin(tableExpr, lhs, rhs) + return createLeftOuterJoin(ctx, tableExpr, lhs, rhs) + case sqlparser.StraightJoinType: + return createStraightJoin(ctx, tableExpr, lhs, rhs) default: panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString())) } diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 787d7fedfcc..d13f79e010f 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -26,7 +26,9 @@ import ( type Join struct { LHS, RHS Operator Predicate sqlparser.Expr - LeftJoin bool + // JoinType is permitted to store only 3 of the possible values + // NormalJoinType, StraightJoinType and LeftJoinType. + JoinType sqlparser.JoinType noColumns } @@ -42,7 +44,7 @@ func (j *Join) Clone(inputs []Operator) Operator { LHS: inputs[0], RHS: inputs[1], Predicate: j.Predicate, - LeftJoin: j.LeftJoin, + JoinType: j.JoinType, } } @@ -61,8 +63,8 @@ func (j *Join) SetInputs(ops []Operator) { } func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) { - if j.LeftJoin { - // we can't merge outer joins into a single QG + if !j.JoinType.IsCommutative() { + // if we can't move tables around, we can't merge these inputs return j, NoRewrite } @@ -83,38 +85,52 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult return newOp, Rewrote("merge querygraphs into a single one") } -func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { - if tableExpr.Join == sqlparser.RightJoinType { +func createStraightJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { + // for inner joins we can treat the predicates as filters on top of the join + joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + + return addJoinPredicates(ctx, join.Condition.On, joinOp) +} + +func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { + // first we switch sides, so we always deal with left outer joins + switch join.Join { + case sqlparser.RightJoinType: lhs, rhs = rhs, lhs + join.Join = sqlparser.LeftJoinType + case sqlparser.NaturalRightJoinType: + lhs, rhs = rhs, lhs + join.Join = sqlparser.NaturalLeftJoinType } - subq, _ := getSubQuery(tableExpr.Condition.On) + + joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join} + + // for outer joins we have to be careful with the predicates we use + var op Operator + subq, _ := getSubQuery(join.Condition.On) if subq != nil { panic(vterrors.VT12001("subquery in outer join predicate")) } - predicate := tableExpr.Condition.On + predicate := join.Condition.On sqlparser.RemoveKeyspaceInCol(predicate) - return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate} -} + joinOp.Predicate = predicate + op = joinOp -func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { - lqg, lok := LHS.(*QueryGraph) - rqg, rok := RHS.(*QueryGraph) - if lok && rok { - op := &QueryGraph{ - Tables: append(lqg.Tables, rqg.Tables...), - innerJoins: append(lqg.innerJoins, rqg.innerJoins...), - NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps), - } - return op - } - return &Join{LHS: LHS, RHS: RHS} + return op } func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator { op := createJoin(ctx, lhs, rhs) + return addJoinPredicates(ctx, tableExpr.Condition.On, op) +} + +func addJoinPredicates( + ctx *plancontext.PlanningContext, + joinPredicate sqlparser.Expr, + op Operator, +) Operator { sqc := &SubQueryBuilder{} outerID := TableID(op) - joinPredicate := tableExpr.Condition.On sqlparser.RemoveKeyspaceInCol(joinPredicate) exprs := sqlparser.SplitAndExpression(nil, joinPredicate) for _, pred := range exprs { @@ -127,6 +143,20 @@ func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.Join return sqc.getRootOperator(op, nil) } +func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { + lqg, lok := LHS.(*QueryGraph) + rqg, rok := RHS.(*QueryGraph) + if lok && rok { + op := &QueryGraph{ + Tables: append(lqg.Tables, rqg.Tables...), + innerJoins: append(lqg.innerJoins, rqg.innerJoins...), + NoDeps: ctx.SemTable.AndExpressions(lqg.NoDeps, rqg.NoDeps), + } + return op + } + return &Join{LHS: LHS, RHS: RHS} +} + func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { return AddPredicate(ctx, j, expr, false, newFilterSinglePredicate) } @@ -150,11 +180,14 @@ func (j *Join) SetRHS(operator Operator) { } func (j *Join) MakeInner() { - j.LeftJoin = false + if j.IsInner() { + return + } + j.JoinType = sqlparser.NormalJoinType } func (j *Join) IsInner() bool { - return !j.LeftJoin + return j.JoinType.IsInner() } func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index 0cc5da9121f..5edc812b1b7 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -92,7 +92,9 @@ type ( joinMerger struct { predicates []sqlparser.Expr - innerJoin bool + // joinType is permitted to store only 3 of the possible values + // NormalJoinType, StraightJoinType and LeftJoinType. + joinType sqlparser.JoinType } routingType int @@ -176,10 +178,10 @@ func getRoutingType(r Routing) routingType { panic(fmt.Sprintf("switch should be exhaustive, got %T", r)) } -func newJoinMerge(predicates []sqlparser.Expr, innerJoin bool) merger { +func newJoinMerge(predicates []sqlparser.Expr, joinType sqlparser.JoinType) merger { return &joinMerger{ predicates: predicates, - innerJoin: innerJoin, + joinType: joinType, } } @@ -203,7 +205,7 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting } func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin { - return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin) + return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), jm.joinType) } func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) *Route { diff --git a/go/vt/vtgate/planbuilder/operators/projection_pushing.go b/go/vt/vtgate/planbuilder/operators/projection_pushing.go index 59f6e6d484d..6df1caee5de 100644 --- a/go/vt/vtgate/planbuilder/operators/projection_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/projection_pushing.go @@ -214,7 +214,7 @@ func pushProjectionInApplyJoin( src *ApplyJoin, ) (Operator, *ApplyResult) { ap, err := p.GetAliasedProjections() - if src.LeftJoin || err != nil { + if !src.IsInner() || err != nil { // we can't push down expression evaluation to the rhs if we are not sure if it will even be executed return p, NoRewrite } diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index f214cb6512e..e6db9d407e3 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -408,7 +408,7 @@ func canPushLeft(ctx *plancontext.PlanningContext, aj *ApplyJoin, order []OrderB func isOuterTable(op Operator, ts semantics.TableSet) bool { aj, ok := op.(*ApplyJoin) - if ok && aj.LeftJoin && TableID(aj.RHS).IsOverlapping(ts) { + if ok && !aj.IsInner() && TableID(aj.RHS).IsOverlapping(ts) { return true } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index f7276ea48c7..f2cf3116f72 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -53,7 +53,7 @@ func pushDerived(ctx *plancontext.PlanningContext, op *Horizon) (Operator, *Appl } func optimizeJoin(ctx *plancontext.PlanningContext, op *Join) (Operator, *ApplyResult) { - return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), !op.LeftJoin) + return mergeOrJoin(ctx, op.LHS, op.RHS, sqlparser.SplitAndExpression(nil, op.Predicate), op.JoinType) } func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (result Operator, changed *ApplyResult) { @@ -147,7 +147,7 @@ func leftToRightSolve(ctx *plancontext.PlanningContext, qg *QueryGraph) Operator continue } joinPredicates := qg.GetPredicates(TableID(acc), TableID(plan)) - acc, _ = mergeOrJoin(ctx, acc, plan, joinPredicates, true) + acc, _ = mergeOrJoin(ctx, acc, plan, joinPredicates, sqlparser.NormalJoinType) } return acc @@ -262,7 +262,7 @@ func getJoinFor(ctx *plancontext.PlanningContext, cm opCacheMap, lhs, rhs Operat return cachedPlan } - join, _ := mergeOrJoin(ctx, lhs, rhs, joinPredicates, true) + join, _ := mergeOrJoin(ctx, lhs, rhs, joinPredicates, sqlparser.NormalJoinType) cm[solves] = join return join } @@ -283,16 +283,16 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op Operator) (requ return } -func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, inner bool) (Operator, *ApplyResult) { - newPlan := mergeJoinInputs(ctx, lhs, rhs, joinPredicates, newJoinMerge(joinPredicates, inner)) +func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, joinType sqlparser.JoinType) (Operator, *ApplyResult) { + newPlan := mergeJoinInputs(ctx, lhs, rhs, joinPredicates, newJoinMerge(joinPredicates, joinType)) if newPlan != nil { return newPlan, Rewrote("merge routes into single operator") } if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) { - if !inner || requiresSwitchingSides(ctx, lhs) { + if !joinType.IsCommutative() || requiresSwitchingSides(ctx, lhs) { // we can't switch sides, so let's see if we can use a HashJoin to solve it - join := NewHashJoin(lhs, rhs, !inner) + join := NewHashJoin(lhs, rhs, !joinType.IsInner()) for _, pred := range joinPredicates { join.AddJoinPredicate(ctx, pred) } @@ -300,12 +300,12 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredic return join, Rewrote("use a hash join because we have LIMIT on the LHS") } - join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, !inner) + join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, joinType) newOp := pushJoinPredicates(ctx, joinPredicates, join) return newOp, Rewrote("logical join to applyJoin, switching side because LIMIT") } - join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, !inner) + join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType) newOp := pushJoinPredicates(ctx, joinPredicates, join) return newOp, Rewrote("logical join to applyJoin ") } diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 960cde99acc..af25136b16a 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -231,7 +231,7 @@ func tryPushSubQueryInJoin( return outer, Rewrote("push subquery into LHS of join") } - if outer.LeftJoin || len(inner.Predicates) == 0 { + if !outer.IsInner() || len(inner.Predicates) == 0 { // we can't push any filters on the RHS of an outer join, and // we don't want to push uncorrelated subqueries to the RHS of a join return nil, NoRewrite @@ -278,7 +278,7 @@ func extractLHSExpr( // tryMergeWithRHS attempts to merge a subquery with the RHS of a join func tryMergeWithRHS(ctx *plancontext.PlanningContext, inner *SubQuery, outer *ApplyJoin) (Operator, *ApplyResult) { - if outer.LeftJoin { + if !outer.IsInner() { return nil, nil } // both sides need to be routes diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 28b54bae9b2..63fba06202c 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2328,6 +2328,74 @@ ] } }, + { + "comment": "Straight Join ensures specific ordering of joins", + "query": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.foo", + "plan": { + "QueryType": "SELECT", + "Original": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.foo", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0", + "JoinVars": { + "user_id": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id from `user` where 1 != 1", + "Query": "select `user`.id from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.user_id from user_extra where 1 != 1", + "Query": "select user_extra.user_id from user_extra where user_extra.foo = :user_id", + "Table": "user_extra" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "Straight Join preserved in MySQL query", + "query": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.user_id", + "plan": { + "QueryType": "SELECT", + "Original": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.user_id", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, user_extra.user_id from `user` straight_join user_extra on `user`.id = user_extra.user_id where 1 != 1", + "Query": "select `user`.id, user_extra.user_id from `user` straight_join user_extra on `user`.id = user_extra.user_id", + "Table": "`user`, user_extra" + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "correlated subquery in exists clause", "query": "select col from user where exists(select user_id from user_extra where user_id = 3 and user_id < user.id)", diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index a8e1442edb8..2e67509c06f 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -47,8 +47,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) - case *sqlparser.JoinTableExpr: - r.handleJoinTableExprDown(node) case *sqlparser.OrExpr: rewriteOrExpr(r.env, cursor, node) case *sqlparser.AndExpr: @@ -223,15 +221,6 @@ func (r *earlyRewriter) handleSelectExprs(cursor *sqlparser.Cursor, node sqlpars return r.expandStar(cursor, node) } -// handleJoinTableExprDown processes JOIN table expressions and handles the Straight Join type. -func (r *earlyRewriter) handleJoinTableExprDown(node *sqlparser.JoinTableExpr) { - if node.Join != sqlparser.StraightJoinType { - return - } - node.Join = sqlparser.NormalJoinType - r.warning = "straight join is converted to normal join" -} - type orderByIterator struct { node sqlparser.OrderBy idx int