diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index a7ceede4cd6..c37194a2c78 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -188,12 +188,10 @@ func (stc *ScatterConn) ExecuteMultiShard( opts = session.Session.Options } - if fetchLastInsertID { - if opts == nil { - opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} - } else { - opts.FetchLastInsertId = fetchLastInsertID - } + if opts != nil { + opts.FetchLastInsertId = fetchLastInsertID + } else if fetchLastInsertID { + opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} } if autocommit { @@ -414,12 +412,10 @@ func (stc *ScatterConn) StreamExecuteMulti( opts = session.Session.Options } - if fetchLastInsertID { - if opts == nil { - opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} - } else { - opts.FetchLastInsertId = fetchLastInsertID - } + if opts != nil { + opts.FetchLastInsertId = fetchLastInsertID + } else if fetchLastInsertID { + opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} } if autocommit { diff --git a/go/vt/vtgate/scatter_conn_test.go b/go/vt/vtgate/scatter_conn_test.go index e6c976c7b74..7fadda2a23b 100644 --- a/go/vt/vtgate/scatter_conn_test.go +++ b/go/vt/vtgate/scatter_conn_test.go @@ -109,6 +109,55 @@ func TestExecuteFailOnAutocommit(t *testing.T) { utils.MustMatch(t, []*querypb.BoundQuery{queries[1]}, sbc1.Queries, "") } +func TestFetchLastInsertIDResets(t *testing.T) { + ctx := utils.LeakCheckContext(t) + + ks := "TestFetchLastInsertIDResets" + createSandbox(ks) + hc := discovery.NewFakeHealthCheck(nil) + sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa") + sbc0 := hc.AddTestTablet("aa", "0", 1, ks, "0", topodatapb.TabletType_PRIMARY, true, 1, nil) + sbc1 := hc.AddTestTablet("aa", "1", 1, ks, "1", topodatapb.TabletType_PRIMARY, true, 1, nil) + + rss := []*srvtopo.ResolvedShard{{ + Target: &querypb.Target{ + Keyspace: ks, + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc0, + }, { + Target: &querypb.Target{ + Keyspace: ks, + Shard: "1", + TabletType: topodatapb.TabletType_PRIMARY, + }, + Gateway: sbc1, + }} + queries := []*querypb.BoundQuery{{ + // This will fail to go to shard. It will be rejected at vtgate. + Sql: "query1", + BindVariables: map[string]*querypb.BindVariable{ + "bv0": sqltypes.Int64BindVariable(0), + }, + }, { + // This will go to shard. + Sql: "query2", + BindVariables: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + }, + }} + + session := econtext.NewSafeSession(&vtgatepb.Session{Options: &querypb.ExecuteOptions{}}) + _, errs := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, true) + require.NoError(t, vterrors.Aggregate(errs)) + assert.True(t, session.Options.FetchLastInsertId) + + _, errs = sc.ExecuteMultiShard(ctx, nil, rss, queries, session, true /*autocommit*/, false, nullResultsObserver{}, false) + require.NoError(t, vterrors.Aggregate(errs)) + assert.False(t, session.Options.FetchLastInsertId) +} + func TestExecutePanic(t *testing.T) { ctx := utils.LeakCheckContext(t) @@ -177,15 +226,10 @@ func TestExecutePanic(t *testing.T) { logMessage = fmt.Sprintf(format, args...) } - defer func() { - r := recover() - require.NotNil(t, r, "The code did not panic") - // assert we are seeing the stack trace - require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction") - }() - - _, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}, false) - + assert.Panics(t, func() { + _, _ = sc.ExecuteMultiShard(ctx, nil, rss, queries, econtext.NewSafeSession(session), true /*autocommit*/, false, nullResultsObserver{}, false) + }) + require.Contains(t, logMessage, "(*ScatterConn).multiGoTransaction") } func TestReservedOnMultiReplica(t *testing.T) {