Skip to content

Commit

Permalink
Prevent Early Ordering Pushdown to Enable Aggregation Pushdown to MyS…
Browse files Browse the repository at this point in the history
…QL (#16278)

Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay authored Jul 4, 2024
1 parent 05728fe commit 694a0cf
Show file tree
Hide file tree
Showing 20 changed files with 1,008 additions and 662 deletions.
51 changes: 51 additions & 0 deletions go/test/endtoend/vtgate/queries/benchmark/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,35 @@ type testQuery struct {
intTyp []bool
}

// deleteUser and deleteUserExtra are the statements used to remove every row
// from the user and user_extra tables between benchmark runs.
var deleteUser, deleteUserExtra = "delete from user", "delete from user_extra"

// generateInserts builds two multi-row INSERT statements: the first fills the
// user table with userSize rows, the second fills the user_extra table with
// userExtraSize rows. In both tables ids run from 1 up to the requested size.
//
// Every user row gets a team_id in [1, userExtraSize] so that it matches the id
// of one of the generated user_extra rows. When userExtraSize is not positive
// there is no row to reference, so team_id is set to 0 instead of computing a
// modulo by zero (which would panic).
//
// A table whose requested size is not positive gets an empty string instead of
// an INSERT with an empty VALUES list, which would not be valid SQL. Callers
// should skip empty statements.
func generateInserts(userSize int, userExtraSize int) (string, string) {
	var userInsertStatement, userExtraInsertStatement string

	// Generate user table inserts
	if userSize > 0 {
		userInserts := make([]string, 0, userSize)
		for i := 1; i <= userSize; i++ {
			typeValue := i % 5 // cycles through 0..4, just an example for type values
			teamID := 0
			if userExtraSize > 0 {
				// To ensure team_id references an existing user_extra id
				teamID = i%userExtraSize + 1
			}
			userInserts = append(userInserts, fmt.Sprintf("(%d, %d, %d, %d)", i, i, typeValue, teamID))
		}
		userInsertStatement = fmt.Sprintf("INSERT INTO user (id, not_sharding_key, type, team_id) VALUES %s;", strings.Join(userInserts, ", "))
	}

	// Generate user_extra table inserts
	if userExtraSize > 0 {
		userExtraInserts := make([]string, 0, userExtraSize)
		for i := 1; i <= userExtraSize; i++ {
			colValue := fmt.Sprintf("col_value_%d", i)
			userExtraInserts = append(userExtraInserts, fmt.Sprintf("(%d, %d, '%s')", i, i, colValue))
		}
		userExtraInsertStatement = fmt.Sprintf("INSERT INTO user_extra (id, not_sharding_key, col) VALUES %s;", strings.Join(userExtraInserts, ", "))
	}

	return userInsertStatement, userExtraInsertStatement
}

func (tq *testQuery) getInsertQuery(rows int) string {
var allRows []string
for i := 0; i < rows; i++ {
Expand Down Expand Up @@ -146,3 +175,25 @@ func BenchmarkShardedTblDeleteIn(b *testing.B) {
})
}
}

// BenchmarkShardedAggrPushDown measures a grouped, ordered sum over a join of
// user and user_extra for every combination of table sizes in rowCounts.
// For each combination the tables are populated before the sub-benchmark runs
// and emptied again afterwards, so each sub-benchmark starts from the row
// counts encoded in its name.
func BenchmarkShardedAggrPushDown(b *testing.B) {
	conn, closer := start(b)
	defer closer()

	// The query under test: aggregation grouped and ordered by the join key.
	const aggrQuery = "select sum(user.type) from user join user_extra on user.team_id = user_extra.id group by user_extra.id order by user_extra.id"

	rowCounts := []int{100, 500, 1000}

	for _, userRows := range rowCounts {
		for _, userExtraRows := range rowCounts {
			// Populate both tables outside of the timed sub-benchmark.
			userInsert, userExtraInsert := generateInserts(userRows, userExtraRows)
			_ = utils.Exec(b, conn, userInsert)
			_ = utils.Exec(b, conn, userExtraInsert)

			name := fmt.Sprintf("user-%d-user_extra-%d", userRows, userExtraRows)
			b.Run(name, func(b *testing.B) {
				for n := 0; n < b.N; n++ {
					_ = utils.Exec(b, conn, aggrQuery)
				}
			})

			// Empty the tables so the next size combination starts clean.
			_ = utils.Exec(b, conn, deleteUser)
			_ = utils.Exec(b, conn, deleteUserExtra)
		}
	}
}
37 changes: 27 additions & 10 deletions go/test/endtoend/vtgate/queries/benchmark/sharded_schema.sql
Original file line number Diff line number Diff line change
@@ -1,16 +1,33 @@
create table tbl_no_lkp_vdx
(
id bigint,
c1 varchar(50),
c2 varchar(50),
c3 varchar(50),
c4 varchar(50),
c5 varchar(50),
c6 varchar(50),
c7 varchar(50),
c8 varchar(50),
c9 varchar(50),
c1 varchar(50),
c2 varchar(50),
c3 varchar(50),
c4 varchar(50),
c5 varchar(50),
c6 varchar(50),
c7 varchar(50),
c8 varchar(50),
c9 varchar(50),
c10 varchar(50),
c11 varchar(50),
c12 varchar(50)
) Engine = InnoDB;
) Engine = InnoDB;

-- user is the left-hand table of the aggregation push-down benchmark join
-- (user.team_id = user_extra.id); the benchmark query sums its type column.
-- team_id holds the id of a user_extra row.
create table user
(
id bigint,
not_sharding_key bigint,
type int,
team_id int,
primary key (id)
);

-- user_extra is the right-hand table of the aggregation push-down benchmark
-- join; its id is matched against user.team_id and is the grouping and
-- ordering column of the benchmark query.
create table user_extra
(
id bigint,
not_sharding_key bigint,
col varchar(50),
primary key (id)
);
16 changes: 16 additions & 0 deletions go/test/endtoend/vtgate/queries/benchmark/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@
"name": "xxhash"
}
]
},
"user": {
"column_vindexes": [
{
"column": "id",
"name": "xxhash"
}
]
},
"user_extra": {
"column_vindexes": [
{
"column": "id",
"name": "xxhash"
}
]
}
}
}
7 changes: 6 additions & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,14 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega
var groupByKeys []*engine.GroupByParams

for _, aggr := range op.Aggregations {
if aggr.OpCode == opcode.AggregateUnassigned {
switch aggr.OpCode {
case opcode.AggregateUnassigned:
return nil, vterrors.VT12001(fmt.Sprintf("in scatter query: aggregation function '%s'", sqlparser.String(aggr.Original)))
case opcode.AggregateUDF:
message := fmt.Sprintf("Aggregate UDF '%s' must be pushed down to MySQL", sqlparser.String(aggr.Original.Expr))
return nil, vterrors.VT12001(message)
}

aggrParam := engine.NewAggregateParam(aggr.OpCode, aggr.ColOffset, aggr.Alias, ctx.VSchema.Environment().CollationEnv())
aggrParam.Func = aggr.Func
if gcFunc, isGc := aggrParam.Func.(*sqlparser.GroupConcatExpr); isGc && gcFunc.Separator == "" {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {

switch stmt := qb.stmt.(type) {
case *sqlparser.Select:
if ContainsAggr(qb.ctx, expr) {
if qb.ctx.ContainsAggr(expr) {
addPred = stmt.AddHaving
} else {
addPred = stmt.AddWhere
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,14 +370,14 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr

columns := &applyJoinColumns{}
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, !join.JoinType.IsInner(), columns, lhs, rhs)
join.JoinColumns = columns
if err != nil {
// if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query
if errors.Is(err, errAbortAggrPushing) {
return nil, nil
}
panic(err)
}
join.JoinColumns = columns

splitGroupingToLeftAndRight(ctx, rootAggr, lhs, rhs, join.JoinColumns)

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, e
case sqlparser.AggrFunc:
aggr = createAggrFromAggrFunc(e, expr)
case *sqlparser.FuncExpr:
if IsAggr(ctx, e) {
if ctx.IsAggr(e) {
aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String())
} else {
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
Expand Down
75 changes: 35 additions & 40 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,19 +202,11 @@ func (aj *ApplyJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
return aj.LHS.GetOrdering(ctx)
}

func joinColumnToExpr(column applyJoinColumn) sqlparser.Expr {
return column.Original
}

func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sqlparser.AliasedExpr, e sqlparser.Expr, addToGroupBy bool) (col applyJoinColumn) {
defer func() {
col.Original = orig.Expr
}()
lhs := TableID(aj.LHS)
rhs := TableID(aj.RHS)
both := lhs.Merge(rhs)
deps := ctx.SemTable.RecursiveDeps(e)
col.GroupBy = addToGroupBy

switch {
case deps.IsSolvedBy(lhs):
Expand All @@ -224,9 +216,12 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq
case deps.IsSolvedBy(both):
col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS))
default:
panic(vterrors.VT13002(sqlparser.String(e)))
panic(vterrors.VT13001(fmt.Sprintf("expression depends on tables outside this join: %s", sqlparser.String(e))))
}

col.GroupBy = addToGroupBy
col.Original = orig.Expr

return
}

Expand Down Expand Up @@ -289,7 +284,7 @@ func (aj *ApplyJoin) AddWSColumn(ctx *plancontext.PlanningContext, offset int, u
if out >= 0 {
aj.addOffset(out)
} else {
col := aj.getJoinColumnFor(ctx, aeWrap(wsExpr), wsExpr, !ContainsAggr(ctx, wsExpr))
col := aj.getJoinColumnFor(ctx, aeWrap(wsExpr), wsExpr, !ctx.ContainsAggr(wsExpr))
aj.JoinColumns.add(col)
aj.planOffsetFor(ctx, col)
}
Expand Down Expand Up @@ -323,33 +318,15 @@ func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
}

func (aj *ApplyJoin) planOffsetFor(ctx *plancontext.PlanningContext, col applyJoinColumn) {
if col.DTColName != nil {
// If DTColName is set, then we already pushed the parts of the expression down while planning.
// We need to use this name and ask the correct side of the join for it. Nothing else is required.
if col.IsPureLeft() {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.DTColName))
aj.addOffset(ToLeftOffset(offset))
} else {
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
}
offset := aj.RHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.DTColName))
aj.addOffset(ToRightOffset(offset))
}
return
}
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
if col.RHSExpr == nil {
// if we don't have an RHS expr, it means that this is a pure LHS expression
aj.addOffset(ToLeftOffset(offset))
} else {
if col.IsPureLeft() {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.GetPureLeftExpr()))
aj.addOffset(ToLeftOffset(offset))
} else {
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
}
}
if col.RHSExpr != nil {
offset := aj.RHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.RHSExpr))
offset := aj.RHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.GetRHSExpr()))
aj.addOffset(ToRightOffset(offset))
}
}
Expand Down Expand Up @@ -443,17 +420,17 @@ func (aj *ApplyJoin) findOrAddColNameBindVarName(ctx *plancontext.PlanningContex
return bvName
}

func (a *ApplyJoin) LHSColumnsNeeded(ctx *plancontext.PlanningContext) (needed sqlparser.Exprs) {
func (aj *ApplyJoin) LHSColumnsNeeded(ctx *plancontext.PlanningContext) (needed sqlparser.Exprs) {
f := func(from BindVarExpr) sqlparser.Expr {
return from.Expr
}
for _, jc := range a.JoinColumns.columns {
for _, jc := range aj.JoinColumns.columns {
needed = append(needed, slice.Map(jc.LHSExprs, f)...)
}
for _, jc := range a.JoinPredicates.columns {
for _, jc := range aj.JoinPredicates.columns {
needed = append(needed, slice.Map(jc.LHSExprs, f)...)
}
needed = append(needed, slice.Map(a.ExtraLHSVars, f)...)
needed = append(needed, slice.Map(aj.ExtraLHSVars, f)...)
return ctx.SemTable.Uniquify(needed)
}

Expand All @@ -462,7 +439,11 @@ func (jc applyJoinColumn) String() string {
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
if jc.DTColName == nil {
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

return fmt.Sprintf("[%s | %s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original), sqlparser.String(jc.DTColName))
}

func (jc applyJoinColumn) IsPureLeft() bool {
Expand All @@ -477,6 +458,20 @@ func (jc applyJoinColumn) IsMixedLeftAndRight() bool {
return len(jc.LHSExprs) > 0 && jc.RHSExpr != nil
}

// GetPureLeftExpr returns the expression to request from the left-hand side
// of the join for this column: DTColName when it is set, otherwise the
// expression of the first entry in LHSExprs.
// NOTE(review): indexes LHSExprs[0] without a length check, so it presumably
// must only be called for columns where IsPureLeft reports true — confirm
// against callers.
func (jc applyJoinColumn) GetPureLeftExpr() sqlparser.Expr {
	if jc.DTColName == nil {
		return jc.LHSExprs[0].Expr
	}
	return jc.DTColName
}

// GetRHSExpr returns the expression to request from the right-hand side of
// the join for this column: DTColName when it is set, otherwise RHSExpr.
func (jc applyJoinColumn) GetRHSExpr() sqlparser.Expr {
	if jc.DTColName == nil {
		return jc.RHSExpr
	}
	return jc.DTColName
}

func (bve BindVarExpr) String() string {
if bve.Name == "" {
return sqlparser.String(bve.Expr)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.
}

newExpr := ctx.RewriteDerivedTableExpression(expr, tableInfo)
if ContainsAggr(ctx, newExpr) {
if ctx.ContainsAggr(newExpr) {
return newFilter(h, expr)
}
h.Source = h.Source.AddPredicate(ctx, newExpr)
Expand Down
Loading

0 comments on commit 694a0cf

Please sign in to comment.