Skip to content

Commit

Permalink
feat: handle last_insert_id with arguments in the evalengine
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 19, 2024
1 parent bea0515 commit b6295b7
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 1 deletion.
3 changes: 2 additions & 1 deletion go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func TestSetAndGetLastInsertID(t *testing.T) {
"update t1 set id2 = last_insert_id(%d) where id1 = 2",
"update t1 set id2 = 88 where id1 = last_insert_id(%d)",
"delete from t1 where id1 = last_insert_id(%d)",
"select id2, last_insert_id(count(*)) from t1 where %d group by id2",
}

for _, workload := range []string{"olap", "oltp"} {
Expand All @@ -175,7 +176,7 @@ func TestSetAndGetLastInsertID(t *testing.T) {
require.NoError(t, err)
}

// Insert a row for UPDATE tests
// Insert a few rows for UPDATE tests
mcmp.Exec("insert into t1 (id1, id2) values (1, 10)")

for _, query := range queries {
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,9 @@ func (t *loggingVCursor) RecordMirrorStats(sourceExecTime, targetExecTime time.D
}
}

func (t *loggingVCursor) SetLastInsertID(id uint64) {}
func (t *noopVCursor) SetLastInsertID(id uint64) {}

func (t *noopVCursor) VExplainLogging() {}
func (t *noopVCursor) DisableLogging() {}
func (t *noopVCursor) GetVExplainLogs() []ExecuteEntry {
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ type (

// RecordMirrorStats is used to record stats about a mirror query.
RecordMirrorStats(time.Duration, time.Duration, error)

SetLastInsertID(uint64)
}

// SessionActions gives primitives ability to interact with the session state
Expand Down
21 changes: 21 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5138,3 +5138,24 @@ func (asm *assembler) Introduce(offset int, t sqltypes.Type, col collations.Type
return 1
}, "INTRODUCE (SP-1)")
}

func (asm *assembler) Fn_LAST_INSERT_ID() {
asm.emit(func(env *ExpressionEnv) int {
arg := env.vm.stack[env.vm.sp-1].(*evalUint64)
env.VCursor().SetLastInsertID(arg.u)
return 1
}, "FN LAST_INSERT_ID UINT64(SP-1)")
}

func (asm *assembler) Fn_LAST_INSERT_ID_NULL() {
asm.emit(func(env *ExpressionEnv) int {
env.VCursor().SetLastInsertID(0)
return 1
}, "FN LAST_INSERT_ID NULL")
}

func (asm *assembler) addJump(end *jump) {
asm.emit(func(env *ExpressionEnv) int {
return end.offset()
}, "JUMP")
}
93 changes: 93 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,99 @@ func TestBindVarLiteral(t *testing.T) {
}
}

type testVcursor struct {
lastInsertID *uint64
env *vtenv.Environment
}

func (t *testVcursor) TimeZone() *time.Location {
return time.UTC
}

func (t *testVcursor) GetKeyspace() string {
return "apa"
}

func (t *testVcursor) SQLMode() string {
return "oltp"
}

func (t *testVcursor) Environment() *vtenv.Environment {
return t.env
}

func (t *testVcursor) SetLastInsertID(id uint64) {
t.lastInsertID = &id
}

var _ evalengine.VCursor = (*testVcursor)(nil)

func TestLastInsertID(t *testing.T) {
var testCases = []struct {
expression string
result uint64
missing bool
}{
{
expression: `last_insert_id(1)`,
result: 1,
}, {
expression: `12`,
missing: true,
}, {
expression: `last_insert_id(666)`,
result: 666,
}, {
expression: `last_insert_id(null)`,
result: 0,
},
}

venv := vtenv.NewTestEnv()
for _, tc := range testCases {
t.Run(tc.expression, func(t *testing.T) {
expr, err := venv.Parser().ParseExpr(tc.expression)
require.NoError(t, err)

cfg := &evalengine.Config{
Collation: collations.CollationUtf8mb4ID,
NoConstantFolding: true,
NoCompilation: false,
Environment: venv,
}
t.Run("eval", func(t *testing.T) {
cfg.NoCompilation = true
runTest(t, expr, cfg, tc)
})
t.Run("compiled", func(t *testing.T) {
cfg.NoCompilation = false
runTest(t, expr, cfg, tc)
})
})
}
}

func runTest(t *testing.T, expr sqlparser.Expr, cfg *evalengine.Config, tc struct {
expression string
result uint64
missing bool
}) {
converted, err := evalengine.Translate(expr, cfg)
require.NoError(t, err)

vc := &testVcursor{env: vtenv.NewTestEnv()}
env := evalengine.NewExpressionEnv(context.Background(), nil, vc)

_, err = env.Evaluate(converted)
require.NoError(t, err)
if tc.missing {
require.Nil(t, vc.lastInsertID)
} else {
require.NotNil(t, vc.lastInsertID)
require.Equal(t, tc.result, *vc.lastInsertID)
}
}

func TestCompilerNonConstant(t *testing.T) {
var testCases = []struct {
expression string
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/evalengine/expr_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type VCursor interface {
GetKeyspace() string
SQLMode() string
Environment() *vtenv.Environment
SetLastInsertID(id uint64)
}

type (
Expand Down Expand Up @@ -140,6 +141,7 @@ func (e *emptyVCursor) GetKeyspace() string {
func (e *emptyVCursor) SQLMode() string {
return config.DefaultSQLMode
}
func (e *emptyVCursor) SetLastInsertID(_ uint64) {}

func NewEmptyVCursor(env *vtenv.Environment, tz *time.Location) VCursor {
return &emptyVCursor{env: env, tz: tz}
Expand Down
40 changes: 40 additions & 0 deletions go/vt/vtgate/evalengine/fn_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ type (
builtinUUIDToBin struct {
CallExpr
}

builtinLastInsertID struct {
CallExpr
}
)

var _ IR = (*builtinInetAton)(nil)
Expand All @@ -95,6 +99,7 @@ var _ IR = (*builtinBinToUUID)(nil)
var _ IR = (*builtinIsUUID)(nil)
var _ IR = (*builtinUUID)(nil)
var _ IR = (*builtinUUIDToBin)(nil)
var _ IR = (*builtinLastInsertID)(nil)

func (call *builtinInetAton) eval(env *ExpressionEnv) (eval, error) {
arg, err := call.arg1(env)
Expand Down Expand Up @@ -155,6 +160,7 @@ func (call *builtinInetNtoa) compile(c *compiler) (ctype, error) {
c.compileToUint64(arg, 1)
col := typedCoercionCollation(sqltypes.VarChar, call.collate)
c.asm.Fn_INET_NTOA(col)

c.asm.jumpDestination(skip)

return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil
Expand Down Expand Up @@ -194,6 +200,40 @@ func (call *builtinInet6Aton) compile(c *compiler) (ctype, error) {
return ctype{Type: sqltypes.VarBinary, Flag: flagNullable, Col: collationBinary}, nil
}

func (call *builtinLastInsertID) eval(env *ExpressionEnv) (eval, error) {
arg, err := call.arg1(env)
if err != nil {
return nil, err
}
if arg == nil {
env.VCursor().SetLastInsertID(0)
return nil, err
}
insertID := uint64(evalToInt64(arg).i)
env.VCursor().SetLastInsertID(insertID)
return newEvalUint64(insertID), nil
}

func (call *builtinLastInsertID) compile(c *compiler) (ctype, error) {
arg, err := call.Arguments[0].compile(c)
if err != nil {
return ctype{}, err
}

setZero := c.compileNullCheck1(arg)
c.compileToUint64(arg, 1)
c.asm.Fn_LAST_INSERT_ID()
end := c.asm.jumpFrom()
c.asm.addJump(end)

c.asm.jumpDestination(setZero)
c.asm.Fn_LAST_INSERT_ID_NULL()

c.asm.jumpDestination(end)

return ctype{Type: sqltypes.Uint64, Flag: flagNullable, Col: collationNumeric}, nil
}

func printIPv6AsIPv4(addr netip.Addr) (netip.Addr, bool) {
b := addr.AsSlice()
if len(b) != 16 {
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/integration/comparison_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ type vcursor struct {
env *vtenv.Environment
}

func (vc *vcursor) SetLastInsertID(id uint64) {}

var _ evalengine.VCursor = (*vcursor)(nil)

func (vc *vcursor) GetKeyspace() string {
return "vttest"
}
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/evalengine/translate_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) {
return nil, argError(method)
}
return &builtinReplace{CallExpr: call, collate: ast.cfg.Collation}, nil
case "last_insert_id":
if len(args) != 1 {
return nil, argError(method)
}
return &builtinLastInsertID{CallExpr: call}, nil
default:
return nil, translateExprNotSupported(fn)
}
Expand Down
6 changes: 6 additions & 0 deletions go/vt/vtgate/executorcontext/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1569,3 +1569,9 @@ func (vc *VCursorImpl) GetContextWithTimeOut(ctx context.Context) (context.Conte
func (vc *VCursorImpl) IgnoreMaxMemoryRows() bool {
return vc.ignoreMaxMemoryRows
}

func (vc *VCursorImpl) SetLastInsertID(id uint64) {
vc.SafeSession.mu.Lock()
defer vc.SafeSession.mu.Unlock()
vc.SafeSession.LastInsertId = id
}

0 comments on commit b6295b7

Please sign in to comment.