Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Data race in semi-join #17417

Merged
merged 4 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,24 @@ func TestTimeZones(t *testing.T) {
})
}
}

// TestSemiJoin tests that the semi join works as intended.
func TestSemiJoin(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 22, "vtgate")
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
mcmp, closer := start(t)
defer closer()

for i := 1; i <= 1000; i++ {
mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
}

// Test that the semi join works as intended
for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
})
}
}
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type fakePrimitive struct {
// sendErr is sent at the end of the stream if it's set.
sendErr error

log []string
noLog bool
log []string

allResultsInOneCall bool

Expand Down Expand Up @@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
}

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
}
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma

// TryStreamExecute performs a streaming exec.
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
var rowAdded atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, lrow)
rowAdded = true
if len(rresult.Rows) > 0 {
rowAdded.Store(true)
}
return nil
})
if err != nil {
return err
}
if rowAdded.Load() {
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
result.Rows = append(result.Rows, lrow)
}
}
return callback(result)
})
Expand Down
79 changes: 79 additions & 0 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"
"testing"

"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) {
"4|d|dd",
))
}

// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution
// to ensure we have no data races.
func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"3|c|cc",
"4|d|dd",
),
},
async: true,
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
// we'll return non-empty results for rows 2 and 4
results: sqltypes.MakeTestStreamingResults(rightFields,
"4|d|dd",
"---",
"---",
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
async: true,
noLog: true,
}

jn := &SemiJoin{
Left: leftPrim,
Right: rightPrim,
Vars: map[string]int{
"bv": 1,
},
}
var res *sqltypes.Result
var mu sync.Mutex
err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if res == nil {
res = result
} else {
res.Rows = append(res.Rows, result.Rows...)
}
return nil
})
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`StreamExecute true`,
})
// We'll get all the rows back in left primitive, since we're returning the same set of rows
// from the right primitive that makes them all qualify.
expectResultAnyOrder(t, res, sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
))
}
Loading