diff --git a/go/vt/vtgate/engine/delete.go b/go/vt/vtgate/engine/delete.go index a039cee46fa..25bf09e2509 100644 --- a/go/vt/vtgate/engine/delete.go +++ b/go/vt/vtgate/engine/delete.go @@ -131,6 +131,9 @@ func (del *Delete) description() PrimitiveDescription { } addFieldsIfNotEmpty(del.DML, other) + if del.FetchLastInsertID { + other["FetchLastInsertID"] = del.FetchLastInsertID + } return PrimitiveDescription{ OperatorType: "Delete", diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index d35d9214186..cb822818d68 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -385,6 +385,10 @@ func (ins *Insert) description() PrimitiveDescription { } } + if ins.FetchLastInsertID { + other["FetchLastInsertID"] = true + } + return PrimitiveDescription{ OperatorType: "Insert", Keyspace: ins.Keyspace, diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index a1dc7b63362..e9044073be9 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -376,6 +376,9 @@ func (route *Route) description() PrimitiveDescription { "Table": route.GetTableName(), "FieldQuery": route.FieldQuery, } + if route.FetchLastInsertID { + other["FetchLastInsertID"] = true + } if route.Vindex != nil { other["Vindex"] = route.Vindex.String() } diff --git a/go/vt/vtgate/engine/update.go b/go/vt/vtgate/engine/update.go index e607b2dd74f..563a032a409 100644 --- a/go/vt/vtgate/engine/update.go +++ b/go/vt/vtgate/engine/update.go @@ -214,6 +214,9 @@ func (upd *Update) description() PrimitiveDescription { if len(changedVindexes) > 0 { other["ChangedVindexValues"] = changedVindexes } + if upd.FetchLastInsertID { + other["FetchLastInsertID"] = upd.FetchLastInsertID + } return PrimitiveDescription{ OperatorType: "Update", diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index df14745e6b2..169260e3d54 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -536,6 +536,7 @@ func routeToEngineRoute(ctx *plancontext.PlanningContext, op *operators.Route, h TableName: strings.Join(tableNames, ", "), RoutingParameters: rp, TruncateColumnCount: op.ResultColumns, + FetchLastInsertID: ctx.SemTable.ShouldFetchLastInsertID(), } if hints != nil { e.ScatterErrorsAsWarnings = hints.scatterErrorsAsWarnings @@ -601,7 +602,7 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( case *sqlparser.Delete: return buildDeletePrimitive(ctx, op, dmlOp, stmt, hints) case *sqlparser.Insert: - return buildInsertPrimitive(op, dmlOp, stmt, hints) + return buildInsertPrimitive(ctx, op, dmlOp, stmt, hints) default: return nil, vterrors.VT13001(fmt.Sprintf("dont know how to %T", stmt)) } @@ -637,7 +638,10 @@ func buildRoutePrimitive(ctx *plancontext.PlanningContext, op *operators.Route, } func buildInsertPrimitive( - rb *operators.Route, op operators.Operator, stmt *sqlparser.Insert, + ctx *plancontext.PlanningContext, + rb *operators.Route, + op operators.Operator, + stmt *sqlparser.Insert, hints *queryHints, ) (engine.Primitive, error) { ins := op.(*operators.Insert) @@ -656,8 +660,9 @@ func buildInsertPrimitive( } eins := &engine.Insert{ - InsertCommon: ic, - VindexValues: ins.VindexValues, + InsertCommon: ic, + VindexValues: ins.VindexValues, + FetchLastInsertID: ctx.SemTable.ShouldFetchLastInsertID(), } // we would need to generate the query on the fly. The only exception here is @@ -788,6 +793,7 @@ func createDMLPrimitive(ctx *plancontext.PlanningContext, rb *operators.Route, h Vindexes: colVindexes, OwnedVindexQuery: vindexQuery, RoutingParameters: rp, + FetchLastInsertID: ctx.SemTable.ShouldFetchLastInsertID(), } if rb.Routing.OpCode() != engine.Unsharded && vindexQuery != "" { diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index f796b935605..3a519f9139d 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -2235,6 +2235,52 @@ ] } }, + { + "comment": "update with last_insert_id in SET", + "query": "update user_extra set val = last_insert_id(123)", + "plan": { + "QueryType": "UPDATE", + "Original": "update user_extra set val = last_insert_id(123)", + "Instructions": { + "OperatorType": "Update", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "FetchLastInsertID": true, + "Query": "update user_extra set val = last_insert_id(123)", + "Table": "user_extra" + }, + "TablesUsed": [ + "user.user_extra" + ] + } + }, + { + "comment": "delete with last_insert_id in where", + "query": "delete from user_extra where val = last_insert_id(123)", + "plan": { + "QueryType": "DELETE", + "Original": "delete from user_extra where val = last_insert_id(123)", + "Instructions": { + "OperatorType": "Delete", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "FetchLastInsertID": true, + "Query": "delete from user_extra where val = last_insert_id(123)", + "Table": "user_extra" + }, + "TablesUsed": [ + "user.user_extra" + ] + } + }, { "comment": "delete from with no where clause", "query": "delete from user_extra", diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index eac13216380..ed95d06babf 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -5599,6 +5599,7 @@ "Name": "user", "Sharded": true }, + "FetchLastInsertID": true, "FieldQuery": "select last_insert_id(id) from `user` where 1 != 1", "Query": "select last_insert_id(id) from `user`", "Table": "`user`" diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index c28edbb0674..de8fbdee0d7 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -638,6 +638,9 @@ func TestQuerySignatureLastInsertID(t *testing.T) { }, { query: "select last_insert_id(123)", expected: true, + }, { + query: "update user_extra set val = last_insert_id(123)", + expected: true, }} for _, tc := range queries { @@ -645,8 +648,7 @@ func TestQuerySignatureLastInsertID(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tc.query) require.NoError(t, err) - sel := ast.(*sqlparser.Select) - st, err := AnalyzeStrict(sel, "dbName", fakeSchemaInfo()) + st, err := AnalyzeStrict(ast, "dbName", fakeSchemaInfo()) require.NoError(t, err) require.Equal(t, tc.expected, st.QuerySignature.LastInsertIDArg) }) diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index 77c8f6db835..492259427c5 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -1014,6 +1014,13 @@ func (st *SemTable) GetMirrorInfo() MirrorInfo { return mirrorInfo(st.Tables) } +func (st *SemTable) ShouldFetchLastInsertID() bool { + if st == nil { + return false + } + return st.QuerySignature.LastInsertIDArg +} + // mirrorInfo looks through all tables with mirror rules defined, and returns a // MirrorInfo containing the lowest mirror percentage found across all rules. //