diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 3f0e896bdf5..eeba0097798 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -520,3 +520,23 @@ func TestTimeZones(t *testing.T) { }) } } + +// TestSemiJoin tests that the semi join works as intended. +func TestSemiJoin(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + for i := 1; i <= 1000; i++ { + mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i)) + mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i)) + } + + // Test that the semi join works as intended + for _, mode := range []string{"oltp", "olap"} { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { + utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) + + mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1") + }) + } +} diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index b878c1931c0..f3ab5ad5336 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -40,7 +40,8 @@ type fakePrimitive struct { // sendErr is sent at the end of the stream if it's set. sendErr error - log []string + noLog bool + log []string allResultsInOneCall bool @@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar } func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + if !f.noLog { + f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + } if f.results == nil { return f.sendErr } diff --git a/go/vt/vtgate/engine/semi_join.go b/go/vt/vtgate/engine/semi_join.go index f0dd0d09033..b5bc74a5941 100644 --- a/go/vt/vtgate/engine/semi_join.go +++ b/go/vt/vtgate/engine/semi_join.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma // TryStreamExecute performs a streaming exec. func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - joinVars := make(map[string]*querypb.BindVariable) err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { + joinVars := make(map[string]*querypb.BindVariable) result := &sqltypes.Result{Fields: lresult.Fields} for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } - rowAdded := false + var rowAdded atomic.Bool err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error { - if len(rresult.Rows) > 0 && !rowAdded { - result.Rows = append(result.Rows, lrow) - rowAdded = true + if len(rresult.Rows) > 0 { + rowAdded.Store(true) } return nil }) if err != nil { return err } + if rowAdded.Load() { + result.Rows = append(result.Rows, lrow) + } } return callback(result) }) diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index 8fee0490415..a103b0686b2 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync" "testing" "vitess.io/vitess/go/test/utils" @@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) { "4|d|dd", )) } + +// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution +// to ensure we have no data races. +func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) { + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + ), sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "3|c|cc", + "4|d|dd", + ), + }, + async: true, + } + rightFields := sqltypes.MakeTestFields( + "col4|col5|col6", + "int64|varchar|varchar", + ) + rightPrim := &fakePrimitive{ + // we'll return non-empty results for rows 2 and 4 + results: sqltypes.MakeTestStreamingResults(rightFields, + "4|d|dd", + "---", + "---", + "5|e|ee", + "6|f|ff", + "7|g|gg", + ), + async: true, + noLog: true, + } + + jn := &SemiJoin{ + Left: leftPrim, + Right: rightPrim, + Vars: map[string]int{ + "bv": 1, + }, + } + var res *sqltypes.Result + var mu sync.Mutex + err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { + mu.Lock() + defer mu.Unlock() + if res == nil { + res = result + } else { + res.Rows = append(res.Rows, result.Rows...) + } + return nil + }) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `StreamExecute true`, + }) + // We'll get all the rows back in left primitive, since we're returning the same set of rows + // from the right primitive that makes them all qualify. + expectResultAnyOrder(t, res, sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + )) +}