From 13c21fdf879e454136e73d245f8b975941dacc83 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Fri, 10 Nov 2023 23:04:21 +0530 Subject: [PATCH] [release-16.0] vtgate/engine: Fix race condition in join logic (#14435) (#14439) Signed-off-by: Dirkjan Bussink Co-authored-by: vitess-bot[bot] <108069721+vitess-bot[bot]@users.noreply.github.com> --- go/vt/vtgate/engine/join.go | 39 ++++++++++++++++------- go/vt/vtgate/engine/scalar_aggregation.go | 2 +- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 6343789aef8..fc44366ff65 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "sync" "sync/atomic" "vitess.io/vitess/go/sqltypes" @@ -96,22 +97,31 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st // TryStreamExecute performs a streaming exec. func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - var fieldNeeded atomic.Bool - fieldNeeded.Store(wantfields) - err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, fieldNeeded.Load(), func(lresult *sqltypes.Result) error { + var mu sync.Mutex + // We need to use this atomic since we're also reading this + // value outside of it being locked with the mu lock. + // This is still racy, but worst case it means that we may + // retrieve the right hand side fields twice instead of once. + var fieldsSent atomic.Bool + fieldsSent.Store(!wantfields) + err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { joinVars := make(map[string]*querypb.BindVariable) for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } var rowSent atomic.Bool - err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), fieldNeeded.Load(), func(rresult *sqltypes.Result) error { + err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), !fieldsSent.Load(), func(rresult *sqltypes.Result) error { + // This needs to be locking since it's not safe to just use + // fieldsSent. This is because we can't have a race between + // checking fieldsSent and then actually calling the callback + // and in parallel another goroutine doing the same. That + // can lead to out of order execution of the callback. So the callback + // itself and the check need to be covered by the same lock. + mu.Lock() + defer mu.Unlock() result := &sqltypes.Result{} - if fieldNeeded.Load() { - // This code is currently unreachable because the first result - // will always be just the field info, which will cause the outer - // wantfields code path to be executed. But this may change in the future. - fieldNeeded.Store(false) + if fieldsSent.CompareAndSwap(false, true) { result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) } for _, rrow := range rresult.Rows { @@ -135,8 +145,15 @@ func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars return callback(result) } } - if fieldNeeded.Load() { - fieldNeeded.Store(false) + // This needs to be locking since it's not safe to just use + // fieldsSent. This is because we can't have a race between + // checking fieldsSent and then actually calling the callback + // and in parallel another goroutine doing the same. That + // can lead to out of order execution of the callback. So the callback + // itself and the check need to be covered by the same lock. + mu.Lock() + defer mu.Unlock() + if fieldsSent.CompareAndSwap(false, true) { for k := range jn.Vars { joinVars[k] = sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/engine/scalar_aggregation.go b/go/vt/vtgate/engine/scalar_aggregation.go index 2b66073ac5e..5dfa8be7b51 100644 --- a/go/vt/vtgate/engine/scalar_aggregation.go +++ b/go/vt/vtgate/engine/scalar_aggregation.go @@ -133,8 +133,8 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor var current []sqltypes.Value var curDistincts []sqltypes.Value var fields []*querypb.Field - fieldsSent := false var mu sync.Mutex + fieldsSent := !wantfields err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error { // as the underlying primitive call is not sync