diff --git a/client/client.go b/client/client.go
index f0d51ef2596..c1af1917b7d 100644
--- a/client/client.go
+++ b/client/client.go
@@ -386,12 +386,22 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) {
 	log.Info("[pd] changing service mode",
 		zap.String("old-mode", c.serviceMode.String()),
 		zap.String("new-mode", newMode.String()))
+	c.resetTSOClientLocked(newMode)
+	oldMode := c.serviceMode
+	c.serviceMode = newMode
+	log.Info("[pd] service mode changed",
+		zap.String("old-mode", oldMode.String()),
+		zap.String("new-mode", newMode.String()))
+}
+
+// Reset a new TSO client.
+func (c *client) resetTSOClientLocked(mode pdpb.ServiceMode) {
 	// Re-create a new TSO client.
 	var (
 		newTSOCli          *tsoClient
 		newTSOSvcDiscovery ServiceDiscovery
 	)
-	switch newMode {
+	switch mode {
 	case pdpb.ServiceMode_PD_SVC_MODE:
 		newTSOCli = newTSOClient(c.ctx, c.option,
 			c.keyspaceID, c.pdSvcDiscovery, &pdTSOStreamBuilderFactory{})
@@ -424,6 +434,7 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) {
 			oldTSOSvcDiscovery.Close()
 		}
 	}
+<<<<<<< HEAD
 	c.serviceMode = newMode
 	log.Info("[pd] service mode changed",
 		zap.String("old-mode", c.serviceMode.String()),
@@ -435,6 +446,27 @@ func (c *client) getTSOClient() *tsoClient {
 		return tsoCli.(*tsoClient)
 	}
 	return nil
+=======
+}
+
+func (c *client) getTSOClient() *tsoClient {
+	c.RLock()
+	defer c.RUnlock()
+	return c.tsoClient
+}
+
+// ResetTSOClient resets the TSO client, only for test.
+func (c *client) ResetTSOClient() {
+	c.Lock()
+	defer c.Unlock()
+	c.resetTSOClientLocked(c.serviceMode)
+}
+
+func (c *client) getServiceMode() pdpb.ServiceMode {
+	c.RLock()
+	defer c.RUnlock()
+	return c.serviceMode
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 }
 
 func (c *client) scheduleUpdateTokenConnection() {
@@ -598,6 +630,7 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur
 		ctx = opentracing.ContextWithSpan(ctx, span)
 	}
 
+<<<<<<< HEAD
 	req := tsoReqPool.Get().(*tsoRequest)
 	req.requestCtx = ctx
 	req.clientCtx = c.ctx
@@ -617,10 +650,59 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur
 	if err = tsoClient.dispatchRequest(dcLocation, req); err != nil {
 		req.done <- err
 	}
+=======
+	req := c.getTSORequest(ctx, dcLocation)
+	if err := c.dispatchTSORequestWithRetry(req); err != nil {
+		req.done <- err
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 	}
 	return req
 }
 
+<<<<<<< HEAD
+=======
+func (c *client) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest {
+	req := tsoReqPool.Get().(*tsoRequest)
+	// Set needed fields in the request before using it.
+	req.start = time.Now()
+	req.clientCtx = c.ctx
+	req.requestCtx = ctx
+	req.physical = 0
+	req.logical = 0
+	req.dcLocation = dcLocation
+	return req
+}
+
+const (
+	dispatchRetryDelay = 50 * time.Millisecond
+	dispatchRetryCount = 2
+)
+
+func (c *client) dispatchTSORequestWithRetry(req *tsoRequest) error {
+	var (
+		retryable bool
+		err       error
+	)
+	for i := 0; i < dispatchRetryCount; i++ {
+		// Do not delay for the first time.
+		if i > 0 {
+			time.Sleep(dispatchRetryDelay)
+		}
+		// Get the tsoClient each time, as it may be initialized or switched during the process.
+		tsoClient := c.getTSOClient()
+		if tsoClient == nil {
+			err = errs.ErrClientGetTSO.FastGenByArgs("tso client is nil")
+			continue
+		}
+		retryable, err = tsoClient.dispatchRequest(req)
+		if !retryable {
+			break
+		}
+	}
+	return err
+}
+
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 func (c *client) GetTS(ctx context.Context) (physical int64, logical int64, err error) {
 	resp := c.GetTSAsync(ctx)
 	return resp.Wait()
diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go
index 842c772abd9..bd7a440fb08 100644
--- a/client/tso_batch_controller.go
+++ b/client/tso_batch_controller.go
@@ -16,7 +16,10 @@ package pd
 
 import (
 	"context"
+	"runtime/trace"
 	"time"
+
+	"github.com/tikv/pd/client/tsoutil"
 )
 
 type tsoBatchController struct {
@@ -130,7 +133,18 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
 	}
 }
 
-func (tbc *tsoBatchController) revokePendingRequest(err error) {
+func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) {
+	for i := 0; i < tbc.collectedRequestCount; i++ {
+		tsoReq := tbc.collectedRequests[i]
+		tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits)
+		defer trace.StartRegion(tsoReq.requestCtx, "pdclient.tsoReqDequeue").End()
+		tsoReq.done <- err
+	}
+	// Prevent the finished requests from being processed again.
+	tbc.collectedRequestCount = 0
+}
+
+func (tbc *tsoBatchController) revokePendingRequests(err error) {
 	for i := 0; i < len(tbc.tsoRequestCh); i++ {
 		req := <-tbc.tsoRequestCh
 		req.done <- err
diff --git a/client/tso_client.go b/client/tso_client.go
index 4feca3d2187..e23b34c9ad4 100644
--- a/client/tso_client.go
+++ b/client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
 		if dispatcherInterface != nil {
 			dispatcher := dispatcherInterface.(*tsoDispatcher)
 			tsoErr := errors.WithStack(errClosing)
-			dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
+			dispatcher.tsoBatchController.revokePendingRequests(tsoErr)
 			dispatcher.dispatcherCancel()
 		}
 		return true
diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index 1aa9543bfa2..ba32dc8a245 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -302,7 +302,8 @@ func (c *tsoClient) createTSODispatcher(dcLocation string) {
 func (c *tsoClient) handleDispatcher(
 	dispatcherCtx context.Context,
 	dc string,
-	tbc *tsoBatchController) {
+	tbc *tsoBatchController,
+) {
 	var (
 		err        error
 		streamAddr string
@@ -377,7 +378,11 @@ tsoBatchLoop:
 		}
 		// Start to collect the TSO requests.
 		maxBatchWaitInterval := c.option.getMaxTSOBatchWaitInterval()
+		// Once the TSO requests are collected, must make sure they could be finished or revoked eventually,
+		// otherwise the upper caller may get blocked on waiting for the results.
 		if err = tbc.fetchPendingRequests(dispatcherCtx, maxBatchWaitInterval); err != nil {
+			// Finish the collected requests if the fetch failed.
+			tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
 			if err == context.Canceled {
 				log.Info("[tso] stop fetching the pending tso requests due to context canceled",
 					zap.String("dc-location", dc))
@@ -406,12 +411,24 @@ tsoBatchLoop:
 			}
 			select {
 			case <-dispatcherCtx.Done():
+<<<<<<< HEAD
+=======
+				// Finish the collected requests if the context is canceled.
+				tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err()))
+				timer.Stop()
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 				return
 			case <-streamLoopTimer.C:
 				err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
 				log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
 				c.svcDiscovery.ScheduleCheckMemberChanged()
+<<<<<<< HEAD
 				c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
+=======
+				// Finish the collected requests if the stream failed to be created.
+				tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
+				timer.Stop()
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 				continue tsoBatchLoop
 			case <-time.After(retryInterval):
 				continue streamChoosingLoop
@@ -443,11 +460,18 @@ tsoBatchLoop:
 			}
 			select {
 			case <-dispatcherCtx.Done():
+				// Finish the collected requests if the context is canceled.
+				tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err()))
 				return
 			case tsDeadlineCh.(chan deadline) <- dl:
 			}
+<<<<<<< HEAD
 			opts = extractSpanReference(tbc, opts[:0])
 			err = c.processRequests(stream, dc, tbc, opts)
+=======
+			// processRequests guarantees that the collected requests could be finished properly.
+			err = c.processRequests(stream, dc, tbc)
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 			close(done)
 			// If error happens during tso stream handling, reset stream and run the next trial.
 			if err != nil {
@@ -698,6 +722,7 @@ func extractSpanReference(tbc *tsoBatchController, opts []opentracing.StartSpanO
 			opts = append(opts, opentracing.ChildOf(span.Context()))
 		}
 	}
+<<<<<<< HEAD
 	return opts
 }
 
@@ -710,14 +735,36 @@ func (c *tsoClient) processRequests(stream tsoStream, dcLocation string, tbc *ts
 	requests := tbc.getCollectedRequests()
 	count := int64(len(requests))
 	physical, logical, suffixBits, err := stream.processRequests(c.svcDiscovery.GetClusterID(), dcLocation, requests, tbc.batchStartTime)
+=======
+
+	count := int64(len(requests))
+	reqKeyspaceGroupID := c.svcDiscovery.GetKeyspaceGroupID()
+	respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests(
+		c.svcDiscovery.GetClusterID(), c.svcDiscovery.GetKeyspaceID(), reqKeyspaceGroupID,
+		dcLocation, count, tbc.batchStartTime)
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 	if err != nil {
-		c.finishRequest(requests, 0, 0, 0, err)
+		tbc.finishCollectedRequests(0, 0, 0, err)
 		return err
 	}
 	// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
+<<<<<<< HEAD
 	firstLogical := addLogical(logical, -count+1, suffixBits)
 	c.compareAndSwapTS(dcLocation, physical, firstLogical, suffixBits, count)
 	c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
+=======
+	firstLogical := tsoutil.AddLogical(logical, -count+1, suffixBits)
+	curTSOInfo := &tsoInfo{
+		tsoServer:           stream.getServerURL(),
+		reqKeyspaceGroupID:  reqKeyspaceGroupID,
+		respKeyspaceGroupID: respKeyspaceGroupID,
+		respReceivedAt:      time.Now(),
+		physical:            physical,
+		logical:             tsoutil.AddLogical(firstLogical, count-1, suffixBits),
+	}
+	c.compareAndSwapTS(dcLocation, curTSOInfo, physical, firstLogical)
+	tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil)
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 	return nil
 }
 
@@ -757,6 +804,7 @@ func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
 	}
 	return physical < thatPhysical
 }
+<<<<<<< HEAD
 
 func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
 	for i := 0; i < len(requests); i++ {
@@ -767,3 +815,5 @@ func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical
 		requests[i].done <- err
 	}
 }
+=======
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
diff --git a/client/tso_stream.go b/client/tso_stream.go
index baa764dffb2..281b5b7629d 100644
--- a/client/tso_stream.go
+++ b/client/tso_stream.go
@@ -97,18 +97,34 @@ func checkStreamTimeout(ctx context.Context, cancel context.CancelFunc, done cha
 
 type tsoStream interface {
 	// processRequests processes TSO requests in streaming mode to get timestamps
+<<<<<<< HEAD
 	processRequests(clusterID uint64, dcLocation string, requests []*tsoRequest,
 		batchStartTime time.Time) (physical, logical int64, suffixBits uint32, err error)
+=======
+	processRequests(
+		clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string,
+		count int64, batchStartTime time.Time,
+	) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error)
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 }
 
 type pdTSOStream struct {
 	stream pdpb.PD_TsoClient
 }
 
+<<<<<<< HEAD
 func (s *pdTSOStream) processRequests(clusterID uint64, dcLocation string, requests []*tsoRequest,
 	batchStartTime time.Time) (physical, logical int64, suffixBits uint32, err error) {
+=======
+func (s *pdTSOStream) getServerURL() string {
+	return s.serverURL
+}
+
+func (s *pdTSOStream) processRequests(
+	clusterID uint64, _, _ uint32, dcLocation string, count int64, batchStartTime time.Time,
+) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) {
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 	start := time.Now()
-	count := int64(len(requests))
 	req := &pdpb.TsoRequest{
 		Header: &pdpb.RequestHeader{
 			ClusterId: clusterID,
@@ -152,10 +168,20 @@ type tsoTSOStream struct {
 	stream tsopb.TSO_TsoClient
 }
 
+<<<<<<< HEAD
 func (s *tsoTSOStream) processRequests(clusterID uint64, dcLocation string, requests []*tsoRequest,
 	batchStartTime time.Time) (physical, logical int64, suffixBits uint32, err error) {
+=======
+func (s *tsoTSOStream) getServerURL() string {
+	return s.serverURL
+}
+
+func (s *tsoTSOStream) processRequests(
+	clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string,
+	count int64, batchStartTime time.Time,
+) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) {
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 	start := time.Now()
-	count := int64(len(requests))
 	req := &tsopb.TsoRequest{
 		Header: &tsopb.RequestHeader{
 			ClusterId: clusterID,
diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go
index c98caa9db7f..22671cfe7b9 100644
--- a/tests/integrations/tso/client_test.go
+++ b/tests/integrations/tso/client_test.go
@@ -20,6 +20,7 @@ import (
 	"math/rand"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -52,6 +53,10 @@ type tsoClientTestSuite struct {
 	client pd.TSOClient
 }
 
+func (suite *tsoClientTestSuite) getBackendEndpoints() []string {
+	return strings.Split(suite.backendEndpoints, ",")
+}
+
 func TestLegacyTSOClient(t *testing.T) {
 	suite.Run(t, &tsoClientTestSuite{
 		legacy: true,
@@ -82,7 +87,11 @@ func (suite *tsoClientTestSuite) SetupSuite() {
 	re.NoError(pdLeader.BootstrapCluster())
 	suite.backendEndpoints = pdLeader.GetAddr()
 	if suite.legacy {
+<<<<<<< HEAD
 		suite.client, err = pd.NewClientWithContext(suite.ctx, strings.Split(suite.backendEndpoints, ","), pd.SecurityOption{})
+=======
+		client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true))
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 		re.NoError(err)
 	} else {
 		suite.tsoServer, suite.tsoServerCleanup = mcs.StartSingleTSOTestServer(suite.ctx, re, suite.backendEndpoints, tempurl.Alloc())
@@ -90,6 +99,39 @@
 	}
 }
 
+<<<<<<< HEAD
+=======
+func (suite *tsoClientTestSuite) waitForAllKeyspaceGroupsInServing(re *require.Assertions) {
+	// The tso servers load keyspace groups asynchronously. Make sure all keyspace groups
+	// are available for serving tso requests from the corresponding keyspaces by querying
+	// IsKeyspaceServing(keyspaceID, the desired keyspaceGroupID). If the default keyspace group ID
+	// is used in the query, it will always return true, as the keyspace is served by the default
+	// keyspace group before the keyspace groups are loaded.
+	testutil.Eventually(re, func() bool {
+		for _, keyspaceGroup := range suite.keyspaceGroups {
+			for _, keyspaceID := range keyspaceGroup.keyspaceIDs {
+				served := false
+				for _, server := range suite.tsoCluster.GetServers() {
+					if server.IsKeyspaceServing(keyspaceID, keyspaceGroup.keyspaceGroupID) {
+						served = true
+						break
+					}
+				}
+				if !served {
+					return false
+				}
+			}
+		}
+		return true
+	}, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond))
+
+	// Create clients and make sure they all have discovered the tso service.
+	suite.clients = mcs.WaitForMultiKeyspacesTSOAvailable(
+		suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints())
+	re.Equal(len(suite.keyspaceIDs), len(suite.clients))
+}
+
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 func (suite *tsoClientTestSuite) TearDownSuite() {
 	suite.cancel()
 	if !suite.legacy {
@@ -140,6 +182,81 @@ func (suite *tsoClientTestSuite) TestGetTSAsync() {
 	wg.Wait()
 }
 
+<<<<<<< HEAD
+=======
+func (suite *tsoClientTestSuite) TestDiscoverTSOServiceWithLegacyPath() {
+	re := suite.Require()
+	keyspaceID := uint32(1000000)
+	// Make sure this keyspace ID is not in use somewhere.
+	re.False(slice.Contains(suite.keyspaceIDs, keyspaceID))
+	failpointValue := fmt.Sprintf(`return(%d)`, keyspaceID)
+	// Simulate the case that the server has a lower version than the client and returns no tso addrs
+	// in the GetClusterInfo RPC.
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/serverReturnsNoTSOAddrs", `return(true)`))
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/unexpectedCallOfFindGroupByKeyspaceID", failpointValue))
+	defer func() {
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/serverReturnsNoTSOAddrs"))
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/unexpectedCallOfFindGroupByKeyspaceID"))
+	}()
+
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	client := mcs.SetupClientWithKeyspaceID(
+		ctx, re, keyspaceID, suite.getBackendEndpoints())
+	defer client.Close()
+	var lastTS uint64
+	for j := 0; j < tsoRequestRound; j++ {
+		physical, logical, err := client.GetTS(ctx)
+		re.NoError(err)
+		ts := tsoutil.ComposeTS(physical, logical)
+		re.Less(lastTS, ts)
+		lastTS = ts
+	}
+}
+
+// TestGetMinTS tests the correctness of GetMinTS.
+func (suite *tsoClientTestSuite) TestGetMinTS() {
+	re := suite.Require()
+	var wg sync.WaitGroup
+	wg.Add(tsoRequestConcurrencyNumber * len(suite.clients))
+	for i := 0; i < tsoRequestConcurrencyNumber; i++ {
+		for _, client := range suite.clients {
+			go func(client pd.Client) {
+				defer wg.Done()
+				var lastMinTS uint64
+				for j := 0; j < tsoRequestRound; j++ {
+					physical, logical, err := client.GetMinTS(suite.ctx)
+					re.NoError(err)
+					minTS := tsoutil.ComposeTS(physical, logical)
+					re.Less(lastMinTS, minTS)
+					lastMinTS = minTS
+
+					// Now we check whether the returned ts is the minimum one
+					// among all keyspace groups, i.e., the returned ts is
+					// less than the new timestamps of all keyspace groups.
+					for _, client := range suite.clients {
+						physical, logical, err := client.GetTS(suite.ctx)
+						re.NoError(err)
+						ts := tsoutil.ComposeTS(physical, logical)
+						re.Less(minTS, ts)
+					}
+				}
+			}(client)
+		}
+	}
+	wg.Wait()
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/unreachableNetwork1", "return(true)"))
+	time.Sleep(time.Second)
+	testutil.Eventually(re, func() bool {
+		var err error
+		_, _, err = suite.clients[0].GetMinTS(suite.ctx)
+		return err == nil
+	})
+	re.NoError(failpoint.Disable("github.com/tikv/pd/client/unreachableNetwork1"))
+}
+
+>>>>>>> c00c42e77 (client/tso: fix the bug that collected TSO requests could never be finished (#7951))
 // More details can be found in this issue: https://github.com/tikv/pd/issues/4884
 func (suite *tsoClientTestSuite) TestUpdateAfterResetTSO() {
 	re := suite.Require()
@@ -235,6 +352,50 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() {
 	suite.SetupSuite()
 }
 
+func (suite *tsoClientTestSuite) TestGetTSWhileResettingTSOClient() {
+	re := suite.Require()
+	var (
+		clients    []pd.Client
+		stopSignal atomic.Bool
+		wg         sync.WaitGroup
+	)
+	// Create independent clients to prevent interfering with other tests.
+	if suite.legacy {
+		client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true))
+		re.NoError(err)
+		clients = []pd.Client{client}
+	} else {
+		clients = mcs.WaitForMultiKeyspacesTSOAvailable(suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints())
+	}
+	wg.Add(tsoRequestConcurrencyNumber * len(clients))
+	for i := 0; i < tsoRequestConcurrencyNumber; i++ {
+		for _, client := range clients {
+			go func(client pd.Client) {
+				defer wg.Done()
+				var lastTS uint64
+				for !stopSignal.Load() {
+					physical, logical, err := client.GetTS(suite.ctx)
+					if err != nil {
+						re.ErrorContains(err, context.Canceled.Error())
+					} else {
+						ts := tsoutil.ComposeTS(physical, logical)
+						re.Less(lastTS, ts)
+						lastTS = ts
+					}
+				}
+			}(client)
+		}
+	}
+	// Reset the TSO clients while requesting TSO concurrently.
+	for i := 0; i < tsoRequestConcurrencyNumber; i++ {
+		for _, client := range clients {
+			client.(interface{ ResetTSOClient() }).ResetTSOClient()
+		}
+	}
+	stopSignal.Store(true)
+	wg.Wait()
+}
+
 // When we upgrade the PD cluster, there may be a period of time that the old and new PDs are running at the same time.
 func TestMixedTSODeployment(t *testing.T) {
 	re := require.New(t)