Skip to content

Commit

Permalink
fix concurrency on stream execute engine primitives (#14586)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
Signed-off-by: Dirkjan Bussink <[email protected]>
Co-authored-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
harshit-gangal and dbussink authored Nov 23, 2023
1 parent 4733045 commit b68b1ed
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 10 deletions.
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package engine
import (
"context"
"fmt"
"sync"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -197,13 +198,16 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map

// TryStreamExecute implements the Primitive interface
func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
pt := newProbeTable(d.CheckCols)
var mu sync.Mutex

pt := newProbeTable(d.CheckCols)
err := vcursor.StreamExecutePrimitive(ctx, d.Source, bindVars, wantfields, func(input *sqltypes.Result) error {
result := &sqltypes.Result{
Fields: input.Fields,
InsertID: input.InsertID,
}
mu.Lock()
defer mu.Unlock()
for _, row := range input.Rows {
exists, err := pt.exists(row)
if err != nil {
Expand Down
53 changes: 53 additions & 0 deletions go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,59 @@ func TestDistinct(t *testing.T) {
}
}

func TestDistinctStreamAsync(t *testing.T) {
distinct := &Distinct{
Source: &fakePrimitive{
results: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("myid|id|num|name", "varchar|int64|int64|varchar"),
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"z|1|1|a",
"a|1|1|t",
"a|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"c|1|1|a",
"a|1|1|a",
"a|1|1|a",
"---",
"c|1|1|a",
"a|1|1|a",
),
async: true,
},
CheckCols: []CheckCol{
{Col: 0, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)},
{Col: 1, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)},
{Col: 2, Type: evalengine.NewType(sqltypes.Int64, collations.CollationBinaryID)},
{Col: 3, Type: evalengine.NewType(sqltypes.VarChar, collations.CollationUtf8mb4ID)},
},
}

qr := &sqltypes.Result{}
err := distinct.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(`
[[VARCHAR("c") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("z") INT64(1) INT64(1) VARCHAR("a")]
[VARCHAR("a") INT64(1) INT64(1) VARCHAR("t")]]`, qr.Rows))
}

func TestWeightStringFallBack(t *testing.T) {
offsetOne := 1
checkCols := []CheckCol{{
Expand Down
51 changes: 49 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/sqltypes"
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand All @@ -41,6 +42,8 @@ type fakePrimitive struct {
log []string

allResultsInOneCall bool

async bool
}

func (f *fakePrimitive) Inputs() ([]Primitive, []map[string]any) {
Expand Down Expand Up @@ -86,6 +89,13 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
return f.sendErr
}

if f.async {
return f.asyncCall(callback)
}
return f.syncCall(wantfields, callback)
}

func (f *fakePrimitive) syncCall(wantfields bool, callback func(*sqltypes.Result) error) error {
readMoreResults := true
for readMoreResults && f.curResult < len(f.results) {
readMoreResults = f.allResultsInOneCall
Expand Down Expand Up @@ -116,9 +126,46 @@ func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, b
}
}
}

return nil
}

func (f *fakePrimitive) asyncCall(callback func(*sqltypes.Result) error) error {
var g errgroup.Group
var fields []*querypb.Field
if len(f.results) > 0 {
fields = f.results[0].Fields
}
for _, res := range f.results {
qr := res
g.Go(func() error {
if qr == nil {
return f.sendErr
}
if err := callback(&sqltypes.Result{Fields: fields}); err != nil {
return err
}
result := &sqltypes.Result{}
for i := 0; i < len(qr.Rows); i++ {
result.Rows = append(result.Rows, qr.Rows[i])
// Send only two rows at a time.
if i%2 == 1 {
if err := callback(result); err != nil {
return err
}
result = &sqltypes.Result{}
}
}
if len(result.Rows) != 0 {
if err := callback(result); err != nil {
return err
}
}
return nil
})
}
return g.Wait()
}

func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars)))
return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */)
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -78,9 +79,14 @@ func (f *Filter) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[s

// TryStreamExecute satisfies the Primitive interface.
func (f *Filter) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var mu sync.Mutex

env := evalengine.NewExpressionEnv(ctx, bindVars, vcursor)
filter := func(results *sqltypes.Result) error {
var rows [][]sqltypes.Value

mu.Lock()
defer mu.Unlock()
for _, row := range results.Rows {
env.Row = row
evalResult, err := env.Evaluate(f.Predicate)
Expand Down
60 changes: 60 additions & 0 deletions go/vt/vtgate/engine/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,63 @@ func TestFilterPass(t *testing.T) {
})
}
}

func TestFilterStreaming(t *testing.T) {
utf8mb4Bin := collationEnv.LookupByName("utf8mb4_bin")
predicate := &sqlparser.ComparisonExpr{
Operator: sqlparser.GreaterThanOp,
Left: sqlparser.NewColName("left"),
Right: sqlparser.NewColName("right"),
}

tcases := []struct {
name string
res []*sqltypes.Result
expRes string
}{{
name: "int32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|int32"), "0|1", "---", "1|0", "2|3"),
expRes: `[[INT32(1) INT32(0)]]`,
}, {
name: "uint16",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|uint16"), "0|1", "1|0", "---", "2|3"),
expRes: `[[UINT16(1) UINT16(0)]]`,
}, {
name: "uint64_int64",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int64"), "0|1", "---", "1|0", "2|3"),
expRes: `[[UINT64(1) INT64(0)]]`,
}, {
name: "int32_uint32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "int32|uint32"), "0|1", "---", "1|0", "---", "2|3"),
expRes: `[[INT32(1) UINT32(0)]]`,
}, {
name: "uint16_int8",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint16|int8"), "0|1", "1|0", "2|3", "---"),
expRes: `[[UINT16(1) INT8(0)]]`,
}, {
name: "uint64_int32",
res: sqltypes.MakeTestStreamingResults(sqltypes.MakeTestFields("left|right", "uint64|int32"), "0|1", "1|0", "2|3", "---", "0|1", "1|3", "5|3"),
expRes: `[[UINT64(1) INT32(0)] [UINT64(5) INT32(3)]]`,
}}
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
pred, err := evalengine.Translate(predicate, &evalengine.Config{
Collation: utf8mb4Bin,
ResolveColumn: evalengine.FieldResolver(tc.res[0].Fields).Column,
})
require.NoError(t, err)

filter := &Filter{
Predicate: pred,
Input: &fakePrimitive{results: tc.res, async: true},
}
qr := &sqltypes.Result{}
err = filter.TryStreamExecute(context.Background(), &noopVCursor{}, nil, false, func(result *sqltypes.Result) error {
qr.Rows = append(qr.Rows, result.Rows...)
return nil
})
require.NoError(t, err)
require.NoError(t, sqltypes.RowsEqualsStr(tc.expRes, qr.Rows))
})
}
}
6 changes: 5 additions & 1 deletion go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"strconv"
"sync"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"
Expand Down Expand Up @@ -97,8 +98,11 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
if len(qr.Fields) != 0 {
mu.Lock()
defer mu.Unlock()
if wantfields && len(qr.Fields) != 0 {
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
Expand Down
67 changes: 67 additions & 0 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,73 @@ func TestLimitStreamExecute(t *testing.T) {
}
}

func TestLimitStreamExecuteAsync(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
"col1|col2",
"int64|varchar",
)
inputResults := sqltypes.MakeTestStreamingResults(
fields,
"a|1",
"b|2",
"d|3",
"e|4",
"a|1",
"b|2",
"d|3",
"e|4",
"---",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"c|7",
"x|8",
"y|9",
"---",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
"l|4",
"m|5",
"n|6",
)
fp := &fakePrimitive{
results: inputResults,
async: true,
}

const maxCount = 26
for i := 0; i <= maxCount*20; i++ {
expRows := i
l := &Limit{
Count: evalengine.NewLiteralInt(int64(expRows)),
Input: fp,
}
// Test with limit smaller than input.
results := &sqltypes.Result{}

err := l.TryStreamExecute(context.Background(), &noopVCursor{}, bindVars, true, func(qr *sqltypes.Result) error {
if qr != nil {
results.Rows = append(results.Rows, qr.Rows...)
}
return nil
})
require.NoError(t, err)
if expRows > maxCount {
expRows = maxCount
}
require.Len(t, results.Rows, expRows)
}

}

func TestOffsetStreamExecute(t *testing.T) {
bindVars := make(map[string]*querypb.BindVariable)
fields := sqltypes.MakeTestFields(
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -101,7 +102,11 @@ func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
Compare: ms.OrderBy,
Limit: count,
}

var mu sync.Mutex
err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if len(qr.Fields) != 0 {
if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
Expand Down
Loading

0 comments on commit b68b1ed

Please sign in to comment.