diff --git a/client/client.go b/client/client.go index 31481b918e6..8838c184d92 100644 --- a/client/client.go +++ b/client/client.go @@ -606,12 +606,22 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) { log.Info("[pd] changing service mode", zap.String("old-mode", c.serviceMode.String()), zap.String("new-mode", newMode.String())) + c.resetTSOClientLocked(newMode) + oldMode := c.serviceMode + c.serviceMode = newMode + log.Info("[pd] service mode changed", + zap.String("old-mode", oldMode.String()), + zap.String("new-mode", newMode.String())) +} + +// Reset a new TSO client. +func (c *client) resetTSOClientLocked(mode pdpb.ServiceMode) { // Re-create a new TSO client. var ( newTSOCli *tsoClient newTSOSvcDiscovery ServiceDiscovery ) - switch newMode { + switch mode { case pdpb.ServiceMode_PD_SVC_MODE: newTSOCli = newTSOClient(c.ctx, c.option, c.pdSvcDiscovery, &pdTSOStreamBuilderFactory{}) @@ -649,11 +659,6 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) { // We are switching from API service mode to PD service mode, so delete the old tso microservice discovery. oldTSOSvcDiscovery.Close() } - oldMode := c.serviceMode - c.serviceMode = newMode - log.Info("[pd] service mode changed", - zap.String("old-mode", oldMode.String()), - zap.String("new-mode", newMode.String())) } func (c *client) getTSOClient() *tsoClient { @@ -662,6 +667,13 @@ func (c *client) getTSOClient() *tsoClient { return c.tsoClient } +// ResetTSOClient resets the TSO client, only for test. +func (c *client) ResetTSOClient() { + c.Lock() + defer c.Unlock() + c.resetTSOClientLocked(c.serviceMode) +} + func (c *client) getServiceMode() pdpb.ServiceMode { c.RLock() defer c.RUnlock() @@ -779,18 +791,25 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur defer span.Finish() } - req := tsoReqPool.Get().(*tsoRequest) - req.requestCtx = ctx - req.clientCtx = c.ctx - req.start = time.Now() - req.dcLocation = dcLocation - + req := c.getTSORequest(ctx, dcLocation) if err := c.dispatchTSORequestWithRetry(req); err != nil { req.done <- err } return req } +func (c *client) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest { + req := tsoReqPool.Get().(*tsoRequest) + // Set needed fields in the request before using it. + req.start = time.Now() + req.clientCtx = c.ctx + req.requestCtx = ctx + req.physical = 0 + req.logical = 0 + req.dcLocation = dcLocation + return req +} + const ( dispatchRetryDelay = 50 * time.Millisecond dispatchRetryCount = 2 diff --git a/client/tso_batch_controller.go b/client/tso_batch_controller.go index 842c772abd9..bd7a440fb08 100644 --- a/client/tso_batch_controller.go +++ b/client/tso_batch_controller.go @@ -16,7 +16,10 @@ package pd import ( "context" + "runtime/trace" "time" + + "github.com/tikv/pd/client/tsoutil" ) type tsoBatchController struct { @@ -130,7 +133,18 @@ func (tbc *tsoBatchController) adjustBestBatchSize() { } } -func (tbc *tsoBatchController) revokePendingRequest(err error) { +func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) { + for i := 0; i < tbc.collectedRequestCount; i++ { + tsoReq := tbc.collectedRequests[i] + tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) + defer trace.StartRegion(tsoReq.requestCtx, "pdclient.tsoReqDequeue").End() + tsoReq.done <- err + } + // Prevent the finished requests from being processed again. + tbc.collectedRequestCount = 0 +} + +func (tbc *tsoBatchController) revokePendingRequests(err error) { for i := 0; i < len(tbc.tsoRequestCh); i++ { req := <-tbc.tsoRequestCh req.done <- err diff --git a/client/tso_client.go b/client/tso_client.go index eeeaf202ce6..c563df0efdb 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -141,7 +141,7 @@ func (c *tsoClient) Close() { if dispatcherInterface != nil { dispatcher := dispatcherInterface.(*tsoDispatcher) tsoErr := errors.WithStack(errClosing) - dispatcher.tsoBatchController.revokePendingRequest(tsoErr) + dispatcher.tsoBatchController.revokePendingRequests(tsoErr) dispatcher.dispatcherCancel() } return true diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 3a6f109bfd4..a625f8dbbe1 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -350,7 +350,8 @@ func (c *tsoClient) createTSODispatcher(dcLocation string) { func (c *tsoClient) handleDispatcher( dispatcherCtx context.Context, dc string, - tbc *tsoBatchController) { + tbc *tsoBatchController, +) { var ( err error streamURL string @@ -428,7 +429,11 @@ tsoBatchLoop: } // Start to collect the TSO requests. maxBatchWaitInterval := c.option.getMaxTSOBatchWaitInterval() + // Once the TSO requests are collected, must make sure they could be finished or revoked eventually, + // otherwise the upper caller may get blocked on waiting for the results. if err = tbc.fetchPendingRequests(dispatcherCtx, maxBatchWaitInterval); err != nil { + // Finish the collected requests if the fetch failed. + tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) if err == context.Canceled { log.Info("[tso] stop fetching the pending tso requests due to context canceled", zap.String("dc-location", dc)) @@ -468,13 +473,16 @@ tsoBatchLoop: timer := time.NewTimer(retryInterval) select { case <-dispatcherCtx.Done(): + // Finish the collected requests if the context is canceled. + tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err())) timer.Stop() return case <-streamLoopTimer.C: err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr) log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err)) c.svcDiscovery.ScheduleCheckMemberChanged() - c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err)) + // Finish the collected requests if the stream is failed to be created. + tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err)) timer.Stop() continue tsoBatchLoop case <-timer.C: @@ -504,9 +512,12 @@ tsoBatchLoop: } select { case <-dispatcherCtx.Done(): + // Finish the collected requests if the context is canceled. + tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err())) return case tsDeadlineCh.(chan *deadline) <- dl: } + // processRequests guarantees that the collected requests could be finished properly. err = c.processRequests(stream, dc, tbc) close(done) // If error happens during tso stream handling, reset stream and run the next trial. @@ -776,13 +787,14 @@ func (c *tsoClient) processRequests( defer span.Finish() } } + count := int64(len(requests)) reqKeyspaceGroupID := c.svcDiscovery.GetKeyspaceGroupID() respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests( c.svcDiscovery.GetClusterID(), c.svcDiscovery.GetKeyspaceID(), reqKeyspaceGroupID, - dcLocation, requests, tbc.batchStartTime) + dcLocation, count, tbc.batchStartTime) if err != nil { - c.finishRequest(requests, 0, 0, 0, err) + tbc.finishCollectedRequests(0, 0, 0, err) return err } // `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request. @@ -796,7 +808,7 @@ func (c *tsoClient) processRequests( logical: tsoutil.AddLogical(firstLogical, count-1, suffixBits), } c.compareAndSwapTS(dcLocation, curTSOInfo, physical, firstLogical) - c.finishRequest(requests, physical, firstLogical, suffixBits, nil) + tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil) return nil } @@ -843,11 +855,3 @@ func (c *tsoClient) compareAndSwapTS( lastTSOInfo.physical = curTSOInfo.physical lastTSOInfo.logical = curTSOInfo.logical } - -func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) { - for i := 0; i < len(requests); i++ { - requests[i].physical, requests[i].logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits) - defer trace.StartRegion(requests[i].requestCtx, "pdclient.tsoReqDequeue").End() - requests[i].done <- err - } -} diff --git a/client/tso_stream.go b/client/tso_stream.go index acefa19d21c..83c0f08d4e0 100644 --- a/client/tso_stream.go +++ b/client/tso_stream.go @@ -106,7 +106,7 @@ type tsoStream interface { // processRequests processes TSO requests in streaming mode to get timestamps processRequests( clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, - requests []*tsoRequest, batchStartTime time.Time, + count int64, batchStartTime time.Time, ) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) } @@ -120,10 +120,9 @@ func (s *pdTSOStream) getServerURL() string { } func (s *pdTSOStream) processRequests( - clusterID uint64, _, _ uint32, dcLocation string, requests []*tsoRequest, batchStartTime time.Time, + clusterID uint64, _, _ uint32, dcLocation string, count int64, batchStartTime time.Time, ) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) { start := time.Now() - count := int64(len(requests)) req := &pdpb.TsoRequest{ Header: &pdpb.RequestHeader{ ClusterId: clusterID, @@ -175,10 +174,9 @@ func (s *tsoTSOStream) getServerURL() string { func (s *tsoTSOStream) processRequests( clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, - requests []*tsoRequest, batchStartTime time.Time, + count int64, batchStartTime time.Time, ) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) { start := time.Now() - count := int64(len(requests)) req := &tsopb.TsoRequest{ Header: &tsopb.RequestHeader{ ClusterId: clusterID, diff --git a/tests/integrations/tso/client_test.go b/tests/integrations/tso/client_test.go index 3d7b099f342..b0bd6f1d4e5 100644 --- a/tests/integrations/tso/client_test.go +++ b/tests/integrations/tso/client_test.go @@ -21,6 +21,7 @@ import ( "math/rand" "strings" "sync" + "sync/atomic" "testing" "time" @@ -66,6 +67,10 @@ type tsoClientTestSuite struct { clients []pd.Client } +func (suite *tsoClientTestSuite) getBackendEndpoints() []string { + return strings.Split(suite.backendEndpoints, ",") +} + func TestLegacyTSOClient(t *testing.T) { suite.Run(t, &tsoClientTestSuite{ legacy: true, @@ -98,7 +103,7 @@ func (suite *tsoClientTestSuite) SetupSuite() { suite.keyspaceIDs = make([]uint32, 0) if suite.legacy { - client, err := pd.NewClientWithContext(suite.ctx, strings.Split(suite.backendEndpoints, ","), pd.SecurityOption{}, pd.WithForwardingOption(true)) + client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true)) re.NoError(err) innerClient, ok := client.(interface{ GetServiceDiscovery() pd.ServiceDiscovery }) re.True(ok) @@ -173,7 +178,7 @@ func (suite *tsoClientTestSuite) waitForAllKeyspaceGroupsInServing(re *require.A // Create clients and make sure they all have discovered the tso service. suite.clients = mcs.WaitForMultiKeyspacesTSOAvailable( - suite.ctx, re, suite.keyspaceIDs, strings.Split(suite.backendEndpoints, ",")) + suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints()) re.Equal(len(suite.keyspaceIDs), len(suite.clients)) } @@ -254,7 +259,7 @@ func (suite *tsoClientTestSuite) TestDiscoverTSOServiceWithLegacyPath() { ctx, cancel := context.WithCancel(suite.ctx) defer cancel() client := mcs.SetupClientWithKeyspaceID( - ctx, re, keyspaceID, strings.Split(suite.backendEndpoints, ",")) + ctx, re, keyspaceID, suite.getBackendEndpoints()) defer client.Close() var lastTS uint64 for j := 0; j < tsoRequestRound; j++ { @@ -420,6 +425,50 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() { re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/fastUpdatePhysicalInterval")) } +func (suite *tsoClientTestSuite) TestGetTSWhileRestingTSOClient() { + re := suite.Require() + var ( + clients []pd.Client + stopSignal atomic.Bool + wg sync.WaitGroup + ) + // Create independent clients to prevent interfering with other tests. + if suite.legacy { + client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true)) + re.NoError(err) + clients = []pd.Client{client} + } else { + clients = mcs.WaitForMultiKeyspacesTSOAvailable(suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints()) + } + wg.Add(tsoRequestConcurrencyNumber * len(clients)) + for i := 0; i < tsoRequestConcurrencyNumber; i++ { + for _, client := range clients { + go func(client pd.Client) { + defer wg.Done() + var lastTS uint64 + for !stopSignal.Load() { + physical, logical, err := client.GetTS(suite.ctx) + if err != nil { + re.ErrorContains(err, context.Canceled.Error()) + } else { + ts := tsoutil.ComposeTS(physical, logical) + re.Less(lastTS, ts) + lastTS = ts + } + } + }(client) + } + } + // Reset the TSO clients while requesting TSO concurrently. + for i := 0; i < tsoRequestConcurrencyNumber; i++ { + for _, client := range clients { + client.(interface{ ResetTSOClient() }).ResetTSOClient() + } + } + stopSignal.Store(true) + wg.Wait() +} + // When we upgrade the PD cluster, there may be a period of time that the old and new PDs are running at the same time. func TestMixedTSODeployment(t *testing.T) { re := require.New(t)