Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Jan 2, 2025
1 parent e749db8 commit 7ff72d5
Show file tree
Hide file tree
Showing 28 changed files with 465 additions and 543 deletions.
4 changes: 2 additions & 2 deletions go/cmd/vtgateclienttest/services/callerid.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bo
return true, fmt.Errorf("SUCCESS: callerid matches")
}

func (c *callerIDClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *callerIDClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
if ok, err := c.checkCallerID(ctx, sql); ok {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *callerIDClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/vtgateclienttest/services/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func echoQueryResult(vals map[string]any) *sqltypes.Result {
return qr
}

func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
if strings.HasPrefix(sql, EchoPrefix) {
return session, echoQueryResult(map[string]any{
"callerId": callerid.EffectiveCallerIDFromContext(ctx),
Expand All @@ -107,7 +107,7 @@ func (c *echoClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLCo
"session": session,
}), nil
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *echoClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (*vtgatepb.Session, error) {
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/vtgateclienttest/services/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ func trimmedRequestToError(received string) error {
}
}

func (c *errorClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *errorClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
if err := requestToPartialError(sql, session); err != nil {
return session, nil, err
}
if err := requestToError(sql); err != nil {
return session, nil, err
}
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables)
return c.fallbackClient.Execute(ctx, mysqlCtx, session, sql, bindVariables, prepared)
}

func (c *errorClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand Down
4 changes: 2 additions & 2 deletions go/cmd/vtgateclienttest/services/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func newFallbackClient(fallback vtgateservice.VTGateService) fallbackClient {
return fallbackClient{fallback: fallback}
}

func (c fallbackClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, mysqlCtx, session, sql, bindVariables)
func (c fallbackClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
return c.fallback.Execute(ctx, mysqlCtx, session, sql, bindVariables, false)
}

func (c fallbackClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/cmd/vtgateclienttest/services/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func newTerminalClient() *terminalClient {
return &terminalClient{}
}

func (c *terminalClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (c *terminalClient) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
if sql == "quit://" {
log.Fatal("Received quit:// query. Going down.")
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vitessdriver/fakeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (q *queryExecute) Equal(q2 *queryExecute) bool {
}

// Execute is part of the VTGateService interface
func (f *fakeVTGateService) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (*vtgatepb.Session, *sqltypes.Result, error) {
func (f *fakeVTGateService) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable, prepared bool) (*vtgatepb.Session, *sqltypes.Result, error) {
execCase, ok := execMap[sql]
if !ok {
return session, nil, fmt.Errorf("no match for: %s", sql)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (vte *VTExplain) vtgateExecute(sql string) ([]*engine.Plan, map[string]*Tab
// This will ensure that the commit/rollback order is predictable.
vte.sortShardSession()

_, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", econtext.NewSafeSession(vte.vtgateSession), sql, nil)
_, err := vte.vtgateExecutor.Execute(context.Background(), nil, "VtexplainExecute", econtext.NewSafeSession(vte.vtgateSession), sql, nil, false)
if err != nil {
for _, tc := range vte.explainTopo.TabletConns {
tc.tabletQueries = nil
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/autocommit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func TestAutocommitTransactionStarted(t *testing.T) {

// single shard query - no savepoint needed
sql := "update `user` set a = 2 where id = 1"
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)
require.Len(t, sbc1.Queries, 1)
require.Equal(t, sql, sbc1.Queries[0].Sql)
Expand All @@ -394,7 +394,7 @@ func TestAutocommitTransactionStarted(t *testing.T) {
// multi shard query - savepoint needed
sql = "update `user` set a = 2 where id in (1, 4)"
expectedSql := "update `user` set a = 2 where id in ::__vals"
_, err = executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err = executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)
require.Len(t, sbc1.Queries, 2)
require.Contains(t, sbc1.Queries[0].Sql, "savepoint")
Expand All @@ -413,7 +413,7 @@ func TestAutocommitDirectTarget(t *testing.T) {
}
sql := "insert into `simple`(val) values ('val')"

_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)

assertQueries(t, sbclookup, []*querypb.BoundQuery{{
Expand All @@ -434,7 +434,7 @@ func TestAutocommitDirectRangeTarget(t *testing.T) {
}
sql := "delete from sharded_user_msgs limit 1000"

_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
_, err := executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}, false)
require.NoError(t, err)

assertQueries(t, sbc1, []*querypb.BoundQuery{{
Expand All @@ -451,5 +451,5 @@ func autocommitExec(executor *Executor, sql string) (*sqltypes.Result, error) {
TransactionMode: vtgatepb.TransactionMode_MULTI,
}

return executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{})
return executor.Execute(context.Background(), nil, "TestExecute", econtext.NewSafeSession(session), sql, map[string]*querypb.BindVariable{}, false)
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func BenchmarkWithNormalizer(b *testing.B) {
},
benchQuery,
nil,
false,
)
if err != nil {
panic(err)
Expand All @@ -92,6 +93,7 @@ func BenchmarkWithoutNormalizer(b *testing.B) {
},
benchQuery,
nil,
false,
)
if err != nil {
panic(err)
Expand Down
121 changes: 69 additions & 52 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,14 @@ func NewExecutor(
}

// Execute executes a non-streaming query.
func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) {
func (e *Executor) Execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, method string, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, prepared bool) (result *sqltypes.Result, err error) {
span, ctx := trace.NewSpan(ctx, "executor.Execute")
span.Annotate("method", method)
trace.AnnotateSQL(span, sqlparser.Preview(sql))
defer span.Finish()

logStats := logstats.NewLogStats(ctx, method, sql, safeSession.GetSessionUUID(), bindVars)
stmtType, result, err := e.execute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats)
stmtType, result, err := e.execute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, prepared)
logStats.Error = err
if result == nil {
saveSessionStats(safeSession, stmtType, 0, 0, err)
Expand Down Expand Up @@ -372,7 +372,7 @@ func (e *Executor) StreamExecute(
return err
}

err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, resultHandler, srr.storeResultStats)
err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, false, resultHandler, srr.storeResultStats)

logStats.Error = err
saveSessionStats(safeSession, srr.stmtType, srr.rowsAffected, srr.rowsReturned, err)
Expand Down Expand Up @@ -424,11 +424,11 @@ func saveSessionStats(safeSession *econtext.SafeSession, stmtType sqlparser.Stat
}
}

func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) (sqlparser.StatementType, *sqltypes.Result, error) {
func (e *Executor) execute(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, prepared bool) (sqlparser.StatementType, *sqltypes.Result, error) {
var err error
var qr *sqltypes.Result
var stmtType sqlparser.StatementType
err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, func(ctx context.Context, plan *engine.Plan, vc *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, time time.Time) error {
err = e.newExecute(ctx, mysqlCtx, safeSession, sql, bindVars, logStats, prepared, func(ctx context.Context, plan *engine.Plan, vc *econtext.VCursorImpl, bindVars map[string]*querypb.BindVariable, time time.Time) error {
stmtType = plan.Type
qr, err = e.executePlan(ctx, safeSession, plan, vc, bindVars, logStats, time)
return err
Expand Down Expand Up @@ -1087,56 +1087,61 @@ func (e *Executor) getPlan(
reservedVars *sqlparser.ReservedVars,
allowParameterization bool,
logStats *logstats.LogStats,
prepared bool,
) (*engine.Plan, error) {
if e.VSchema() == nil {
return nil, vterrors.VT13001("vschema not initialized")
}

qh, err := sqlparser.BuildQueryHints(stmt)
if err != nil {
return nil, err
}
vcursor.SetIgnoreMaxMemoryRows(qh.IgnoreMaxMemoryRows)
vcursor.SetConsolidator(qh.Consolidator)
vcursor.SetWorkloadName(qh.Workload)
vcursor.UpdateForeignKeyChecksState(qh.ForeignKeyChecks)
vcursor.SetPriority(qh.Priority)
vcursor.SetExecQueryTimeout(qh.Timeout)

setVarComment, err := prepareSetVarComment(vcursor, stmt)
if err != nil {
return nil, err
}
var bindVarNeeds *sqlparser.BindVarNeeds
if !prepared {
qh, err := sqlparser.BuildQueryHints(stmt)
if err != nil {
return nil, err
}
vcursor.SetIgnoreMaxMemoryRows(qh.IgnoreMaxMemoryRows)
vcursor.SetConsolidator(qh.Consolidator)
vcursor.SetWorkloadName(qh.Workload)
vcursor.UpdateForeignKeyChecksState(qh.ForeignKeyChecks)
vcursor.SetPriority(qh.Priority)
vcursor.SetExecQueryTimeout(qh.Timeout)

setVarComment, err := prepareSetVarComment(vcursor, stmt)
if err != nil {
return nil, err
}

// Normalize if possible
shouldNormalize := e.canNormalizeStatement(stmt, setVarComment)
parameterize := allowParameterization && shouldNormalize

rewriteASTResult, err := sqlparser.PrepareAST(
stmt,
reservedVars,
bindVars,
parameterize,
vcursor.GetKeyspace(),
vcursor.SafeSession.GetSelectLimit(),
setVarComment,
vcursor.GetSystemVariablesCopy(),
vcursor.GetForeignKeyChecksState(),
vcursor,
)
if err != nil {
return nil, err
}
stmt = rewriteASTResult.AST
bindVarNeeds := rewriteASTResult.BindVarNeeds
if shouldNormalize {
query = sqlparser.String(stmt)
// Normalize if possible
shouldNormalize := e.canNormalizeStatement(stmt, setVarComment)
parameterize := allowParameterization && shouldNormalize

rewriteASTResult, err := sqlparser.PrepareAST(
stmt,
reservedVars,
bindVars,
parameterize,
vcursor.GetKeyspace(),
vcursor.SafeSession.GetSelectLimit(),
setVarComment,
vcursor.GetSystemVariablesCopy(),
vcursor.GetForeignKeyChecksState(),
vcursor,
)
if err != nil {
return nil, err
}
stmt = rewriteASTResult.AST
bindVarNeeds = rewriteASTResult.BindVarNeeds
// if shouldNormalize {
// // log.Infof("before query: %v", query)
// // query = sqlparser.String(stmt)
// // log.Infof("after query: %v", query)
// }
}

logStats.SQL = comments.Leading + query + comments.Trailing
logStats.BindVariables = sqltypes.CopyBindVariables(bindVars)

return e.cacheAndBuildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds, logStats)
return e.cacheAndBuildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds, logStats, prepared)
}

func (e *Executor) hashPlan(ctx context.Context, vcursor *econtext.VCursorImpl, query string) PlanCacheKey {
Expand Down Expand Up @@ -1179,10 +1184,15 @@ func (e *Executor) cacheAndBuildStatement(
reservedVars *sqlparser.ReservedVars,
bindVarNeeds *sqlparser.BindVarNeeds,
logStats *logstats.LogStats,
prepared bool,
) (*engine.Plan, error) {
planCachable := sqlparser.CachePlan(stmt) && vcursor.CachePlan()
if planCachable {
planCachable := true
if !prepared {
planCachable = sqlparser.CachePlan(stmt) && vcursor.CachePlan()
}
if prepared || planCachable {
planKey := e.hashPlan(ctx, vcursor, query)
// log.Infof("Plan cache key: %v for query: %v", planKey, query)

var plan *engine.Plan
var err error
Expand All @@ -1191,6 +1201,9 @@ func (e *Executor) cacheAndBuildStatement(
})
return plan, err
}
// if stmt == nil {
// panic("unexpected for query: " + query + " when prepared: " + strconv.FormatBool(prepared) + " planCachable: " + strconv.FormatBool(planCachable))
// }
return e.buildStatement(ctx, vcursor, query, stmt, reservedVars, bindVarNeeds)
}

Expand Down Expand Up @@ -1386,8 +1399,11 @@ func (e *Executor) prepare(ctx context.Context, safeSession *econtext.SafeSessio

switch stmtType {
case sqlparser.StmtSelect, sqlparser.StmtShow:
return e.handlePrepare(ctx, safeSession, sql, bindVars, logStats)
case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete,
return e.handlePrepare(ctx, safeSession, sql, bindVars, logStats, false)
case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete:
_, err := e.handlePrepare(ctx, safeSession, sql, bindVars, logStats, true)
return nil, err
case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet,
sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtAnalyze, sqlparser.StmtComment, sqlparser.StmtExplain, sqlparser.StmtFlush, sqlparser.StmtKill:
return nil, nil
}
Expand Down Expand Up @@ -1425,7 +1441,7 @@ func (e *Executor) initVConfig(warnOnShardedOnly bool, pv plancontext.PlannerVer
}
}

func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats) ([]*querypb.Field, error) {
func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *logstats.LogStats, skipFields bool) ([]*querypb.Field, error) {
query, comments := sqlparser.SplitMarginComments(sql)

vcursor, _ := econtext.NewVCursorImpl(safeSession, comments, e, logStats, e.vm, e.VSchema(), e.resolver.resolver, e.serv, nullResultsObserver{}, e.vConfig)
Expand All @@ -1435,7 +1451,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.Safe
return nil, err
}

plan, err := e.getPlan(ctx, vcursor, sql, stmt, comments, bindVars, reservedVars /* parameterize */, false, logStats)
plan, err := e.getPlan(ctx, vcursor, sql, stmt, comments, bindVars, reservedVars /* parameterize */, false, logStats, false)
execStart := time.Now()
logStats.PlanTime = execStart.Sub(logStats.StartTime)

Expand All @@ -1445,7 +1461,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *econtext.Safe
}

err = e.addNeededBindVars(vcursor, plan.BindVarNeeds, bindVars, safeSession)
if err != nil {
if err != nil || skipFields {
logStats.Error = err
return nil, err
}
Expand Down Expand Up @@ -1621,6 +1637,7 @@ func (e *Executor) PlanPrepareStmt(ctx context.Context, vcursor *econtext.VCurso
reservedVars, /* normalize */
false,
lStats,
false,
)
if err != nil {
return nil, nil, err
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/executor_ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestDDLFlags(t *testing.T) {
session := econtext.NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded})
enableDirectDDL.Set(testcase.enableDirectDDL)
enableOnlineDDL.Set(testcase.enableOnlineDDL)
_, err := executor.Execute(ctx, nil, "TestDDLFlags", session, testcase.sql, nil)
_, err := executor.Execute(ctx, nil, "TestDDLFlags", session, testcase.sql, nil, false)
if testcase.wantErr {
require.EqualError(t, err, testcase.err)
} else {
Expand Down
Loading

0 comments on commit 7ff72d5

Please sign in to comment.