Skip to content

Commit

Permalink
feat: make sure to check last insert id in transactions, and don't fi…
Browse files Browse the repository at this point in the history
…nish early when limiting row count

Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 18, 2024
1 parent 6c1a3d2 commit 880284d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 36 deletions.
28 changes: 11 additions & 17 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package engine
import (
"context"
"fmt"
"io"
"strconv"
"sync"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/evalengine"

Expand Down Expand Up @@ -105,14 +106,15 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
log.Errorf("LastInsertID: %d InsertIDChanged %t\n", qr.InsertID, qr.InsertIDChanged)
if wantfields && len(qr.Fields) != 0 {
if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
return err
}
}
inputSize := len(qr.Rows)
if inputSize == 0 {
return nil
return callback(qr)
}

// we've still not seen all rows we need to see before we can return anything to the client
Expand All @@ -126,30 +128,22 @@ func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars
offset = 0
}

if count == 0 {
return io.EOF
}

// reduce count till 0.
result := &sqltypes.Result{Rows: qr.Rows}
resultSize := len(result.Rows)
resultSize := len(qr.Rows)
if count > resultSize {
count -= resultSize
return callback(result)
return callback(qr)
}
result.Rows = result.Rows[:count]

qr.Rows = qr.Rows[:count]
count = 0
if err := callback(result); err != nil {
if err := callback(qr); err != nil {
return err
}
return io.EOF
})

if err == io.EOF {
// We may get back the EOF we returned in the callback.
// If so, suppress it.
return nil
}
})

if err != nil {
return err
}
Expand Down
49 changes: 30 additions & 19 deletions go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1202,16 +1202,8 @@ func (qre *QueryExecutor) fetchLastInsertID(ctx context.Context, conn *connpool.

func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction bool, sql string, callback func(*sqltypes.Result) error) error {
span, ctx := trace.NewSpan(qre.ctx, "QueryExecutor.execStreamSQL")
defer span.Finish()
trace.AnnotateSQL(span, sqlparser.Preview(sql))
callBackClosingSpan := func(result *sqltypes.Result) error {
defer span.Finish()

// if err := qre.fetchLastInsertID(ctx, conn.Conn, result); err != nil {
// return err
// }

return callback(result)
}

start := time.Now()
defer qre.logStats.AddRewrittenSQL(sql, start)
Expand All @@ -1222,28 +1214,47 @@ func (qre *QueryExecutor) execStreamSQL(conn *connpool.PooledConn, isTransaction
// This change will ensure that long-running streaming stateful queries get gracefully shutdown during ServingTypeChange
// once their grace period is over.
qd := NewQueryDetail(qre.logStats.Ctx, conn.Conn)
// if err := qre.resetLastInsertIDIfNeeded(ctx, conn.Conn); err != nil {
// return err
// }

if err := qre.resetLastInsertIDIfNeeded(ctx, conn.Conn); err != nil {
return err
}

lastInsertIDSet := false
cb := func(result *sqltypes.Result) error {
if result != nil && result.InsertID != 0 {
lastInsertIDSet = true
}
return callback(result)
}

var err error
if isTransaction {
err := qre.tsv.statefulql.Add(qd)
err = qre.tsv.statefulql.Add(qd)
if err != nil {
return err
}
defer qre.tsv.statefulql.Remove(qd)
err = conn.Conn.StreamOnce(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
err = conn.Conn.StreamOnce(ctx, sql, cb, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
} else {
err = qre.tsv.olapql.Add(qd)
if err != nil {
return err
}
return nil
defer qre.tsv.olapql.Remove(qd)
err = conn.Conn.Stream(ctx, sql, cb, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
}
err := qre.tsv.olapql.Add(qd)
if err != nil {

if err != nil || lastInsertIDSet || !qre.options.GetFetchLastInsertId() {
return err
}
res := &sqltypes.Result{}
if err = qre.fetchLastInsertID(ctx, conn.Conn, res); err != nil {
return err
}
defer qre.tsv.olapql.Remove(qd)
return conn.Conn.Stream(ctx, sql, callBackClosingSpan, allocStreamResult, int(qre.tsv.qe.streamBufferSize.Load()), sqltypes.IncludeFieldsOrDefault(qre.options))
if res.InsertIDChanged {
return callback(res)
}
return nil
}

func (qre *QueryExecutor) recordUserQuery(queryType string, duration int64) {
Expand Down

0 comments on commit 880284d

Please sign in to comment.