Fix the bug that collected TSO requests could never be finished
Signed-off-by: JmPotato <[email protected]>
JmPotato committed Mar 20, 2024
1 parent b5c56b6 commit 51ad007
Showing 6 changed files with 119 additions and 35 deletions.
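
For context on the bug in the title: every tsoRequest carries a `done` channel that its caller blocks on, so a batch that is collected but never finished leaves those callers waiting forever. The sketch below is illustrative only (the `request`/`wait` names are hypothetical stand-ins, not the client's real types); it shows how a dropped batch manifests as a hung caller, and how writing an error to `done` on every exit path unblocks it.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// request is a simplified stand-in for the client's tsoRequest: the caller
// blocks on done until the dispatcher delivers a result or an error.
type request struct {
	done chan error
}

// wait mimics a caller waiting for its request to be finished, with a local
// timeout so the hung case can be demonstrated without blocking forever.
func wait(req *request, timeout time.Duration) error {
	select {
	case err := <-req.done:
		return err
	case <-time.After(timeout):
		return errors.New("caller is stuck: request was collected but never finished")
	}
}

func main() {
	// Buggy shape: the batch hits an error path that drops the requests,
	// nothing is ever written to done, and the caller just waits.
	leaked := &request{done: make(chan error, 1)}
	fmt.Println(wait(leaked, 100*time.Millisecond))

	// Fixed shape: every collected request is finished, even on error,
	// so the caller is promptly unblocked.
	finished := &request{done: make(chan error, 1)}
	finished.done <- errors.New("stream creation failed")
	fmt.Println(wait(finished, 100*time.Millisecond))
}
```
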
43 changes: 31 additions & 12 deletions client/client.go
@@ -606,12 +606,22 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) {
log.Info("[pd] changing service mode",
zap.String("old-mode", c.serviceMode.String()),
zap.String("new-mode", newMode.String()))
c.resetTSOClientLocked(newMode)
oldMode := c.serviceMode
c.serviceMode = newMode
log.Info("[pd] service mode changed",
zap.String("old-mode", oldMode.String()),
zap.String("new-mode", newMode.String()))
}

// Reset the TSO client with the given service mode.
func (c *client) resetTSOClientLocked(mode pdpb.ServiceMode) {
// Re-create a new TSO client.
var (
newTSOCli *tsoClient
newTSOSvcDiscovery ServiceDiscovery
)
switch newMode {
switch mode {
case pdpb.ServiceMode_PD_SVC_MODE:
newTSOCli = newTSOClient(c.ctx, c.option,
c.pdSvcDiscovery, &pdTSOStreamBuilderFactory{})
@@ -649,11 +659,6 @@ func (c *client) setServiceMode(newMode pdpb.ServiceMode) {
// We are switching from API service mode to PD service mode, so delete the old tso microservice discovery.
oldTSOSvcDiscovery.Close()
}
oldMode := c.serviceMode
c.serviceMode = newMode
log.Info("[pd] service mode changed",
zap.String("old-mode", oldMode.String()),
zap.String("new-mode", newMode.String()))
}

func (c *client) getTSOClient() *tsoClient {
@@ -662,6 +667,13 @@ func (c *client) getTSOClient() *tsoClient {
return c.tsoClient
}

// ResetTSOClient resets the TSO client, only for test.
func (c *client) ResetTSOClient() {
c.Lock()
defer c.Unlock()
c.resetTSOClientLocked(c.serviceMode)
}

func (c *client) getServiceMode() pdpb.ServiceMode {
c.RLock()
defer c.RUnlock()
@@ -779,18 +791,25 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur
defer span.Finish()
}

req := tsoReqPool.Get().(*tsoRequest)
req.requestCtx = ctx
req.clientCtx = c.ctx
req.start = time.Now()
req.dcLocation = dcLocation

req := c.getTSORequest(ctx, dcLocation)
if err := c.dispatchTSORequestWithRetry(req); err != nil {
req.done <- err
}
return req
}

func (c *client) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest {
req := tsoReqPool.Get().(*tsoRequest)
// Set needed fields in the request before using it.
req.start = time.Now()
req.clientCtx = c.ctx
req.requestCtx = ctx
req.physical = 0
req.logical = 0
req.dcLocation = dcLocation
return req
}

const (
dispatchRetryDelay = 50 * time.Millisecond
dispatchRetryCount = 2
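
The new `getTSORequest` above re-initializes a request taken from `tsoReqPool` before handing it out. Below is a minimal sketch of why that reset matters, assuming the pool behaves like a plain `sync.Pool` of reusable request structs (the `tsoReq`/`getReq` names are stand-ins, not the client's real identifiers): a recycled object still carries whatever `physical`/`logical` values its previous user left behind.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// tsoReq mirrors the handful of fields the real request resets on reuse.
type tsoReq struct {
	start             time.Time
	physical, logical int64
	dcLocation        string
}

var reqPool = sync.Pool{New: func() any { return &tsoReq{} }}

// getReq re-initializes every field that matters before the request is handed
// out, so a recycled object cannot leak a stale timestamp into a new call.
func getReq(dc string) *tsoReq {
	r := reqPool.Get().(*tsoReq)
	r.start = time.Now()
	r.physical, r.logical = 0, 0
	r.dcLocation = dc
	return r
}

func main() {
	r := getReq("global")
	r.physical, r.logical = 42, 7 // pretend a response filled it in
	reqPool.Put(r)                // the object goes back carrying stale values

	r2 := getReq("global")
	fmt.Println(r2.physical, r2.logical) // prints 0 0 thanks to the reset
}
```
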
16 changes: 15 additions & 1 deletion client/tso_batch_controller.go
@@ -16,7 +16,10 @@ package pd

import (
"context"
"runtime/trace"
"time"

"github.com/tikv/pd/client/tsoutil"
)

type tsoBatchController struct {
@@ -130,7 +133,18 @@ func (tbc *tsoBatchController) adjustBestBatchSize() {
}
}

func (tbc *tsoBatchController) revokePendingRequest(err error) {
func (tbc *tsoBatchController) finishCollectedRequests(physical, firstLogical int64, suffixBits uint32, err error) {
for i := 0; i < tbc.collectedRequestCount; i++ {
tsoReq := tbc.collectedRequests[i]
tsoReq.physical, tsoReq.logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits)
defer trace.StartRegion(tsoReq.requestCtx, "pdclient.tsoReqDequeue").End()
tsoReq.done <- err
}
// Prevent the finished requests from being processed again.
tbc.collectedRequestCount = 0
}

func (tbc *tsoBatchController) revokePendingRequests(err error) {
for i := 0; i < len(tbc.tsoRequestCh); i++ {
req := <-tbc.tsoRequestCh
req.done <- err
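
`finishCollectedRequests` above hands every collected request the same `physical` value plus its own logical offset via `tsoutil.AddLogical`, then zeroes `collectedRequestCount` so the batch cannot be finished twice. A small worked sketch of that offset arithmetic follows, under the assumption that `AddLogical` simply adds the count shifted left by `suffixBits` (the `addLogical` helper below is a local re-implementation for illustration only).

```go
package main

import "fmt"

// addLogical mirrors what client/tsoutil.AddLogical is assumed to do here:
// advance the logical part by count, shifted past the low suffixBits that
// encode the keyspace-group suffix.
func addLogical(logical, count int64, suffixBits uint32) int64 {
	return logical + count<<suffixBits
}

func main() {
	// One TSO response covers the whole batch: the dispatcher gets a physical
	// part and a first logical value, and each collected request i receives
	// firstLogical advanced by i.
	var (
		physical     int64  = 1_700_000_000_000
		firstLogical int64  = 100
		suffixBits   uint32 = 0
	)
	for i := int64(0); i < 3; i++ {
		fmt.Printf("request %d -> physical=%d logical=%d\n",
			i, physical, addLogical(firstLogical, i, suffixBits))
	}
}
```
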
2 changes: 1 addition & 1 deletion client/tso_client.go
@@ -141,7 +141,7 @@ func (c *tsoClient) Close() {
if dispatcherInterface != nil {
dispatcher := dispatcherInterface.(*tsoDispatcher)
tsoErr := errors.WithStack(errClosing)
dispatcher.tsoBatchController.revokePendingRequest(tsoErr)
dispatcher.tsoBatchController.revokePendingRequests(tsoErr)
dispatcher.dispatcherCancel()
}
return true
30 changes: 17 additions & 13 deletions client/tso_dispatcher.go
@@ -350,7 +350,8 @@ func (c *tsoClient) createTSODispatcher(dcLocation string) {
func (c *tsoClient) handleDispatcher(
dispatcherCtx context.Context,
dc string,
tbc *tsoBatchController) {
tbc *tsoBatchController,
) {
var (
err error
streamURL string
@@ -428,7 +429,11 @@ tsoBatchLoop:
}
// Start to collect the TSO requests.
maxBatchWaitInterval := c.option.getMaxTSOBatchWaitInterval()
// Once the TSO requests are collected, we must make sure they are eventually finished or revoked;
// otherwise, the upper caller may get blocked waiting for the results.
if err = tbc.fetchPendingRequests(dispatcherCtx, maxBatchWaitInterval); err != nil {
// Finish the collected requests if the fetch failed.
tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
if err == context.Canceled {
log.Info("[tso] stop fetching the pending tso requests due to context canceled",
zap.String("dc-location", dc))
@@ -468,13 +473,16 @@ tsoBatchLoop:
timer := time.NewTimer(retryInterval)
select {
case <-dispatcherCtx.Done():
// Finish the collected requests if the context is canceled.
tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err()))
timer.Stop()
return
case <-streamLoopTimer.C:
err = errs.ErrClientCreateTSOStream.FastGenByArgs(errs.RetryTimeoutErr)
log.Error("[tso] create tso stream error", zap.String("dc-location", dc), errs.ZapError(err))
c.svcDiscovery.ScheduleCheckMemberChanged()
c.finishRequest(tbc.getCollectedRequests(), 0, 0, 0, errors.WithStack(err))
// Finish the collected requests if the stream could not be created.
tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(err))
timer.Stop()
continue tsoBatchLoop
case <-timer.C:
@@ -504,9 +512,12 @@ tsoBatchLoop:
}
select {
case <-dispatcherCtx.Done():
// Finish the collected requests if the context is canceled.
tbc.finishCollectedRequests(0, 0, 0, errors.WithStack(dispatcherCtx.Err()))
return
case tsDeadlineCh.(chan *deadline) <- dl:
}
// processRequests guarantees that the collected requests are finished properly.
err = c.processRequests(stream, dc, tbc)
close(done)
// If error happens during tso stream handling, reset stream and run the next trial.
@@ -776,13 +787,14 @@ func (c *tsoClient) processRequests(
defer span.Finish()
}
}

count := int64(len(requests))
reqKeyspaceGroupID := c.svcDiscovery.GetKeyspaceGroupID()
respKeyspaceGroupID, physical, logical, suffixBits, err := stream.processRequests(
c.svcDiscovery.GetClusterID(), c.svcDiscovery.GetKeyspaceID(), reqKeyspaceGroupID,
dcLocation, requests, tbc.batchStartTime)
dcLocation, count, tbc.batchStartTime)
if err != nil {
c.finishRequest(requests, 0, 0, 0, err)
tbc.finishCollectedRequests(0, 0, 0, err)
return err
}
// `logical` is the largest ts's logical part here, so we need to do the subtraction before finishing each TSO request.
@@ -796,7 +808,7 @@
logical: tsoutil.AddLogical(firstLogical, count-1, suffixBits),
}
c.compareAndSwapTS(dcLocation, curTSOInfo, physical, firstLogical)
c.finishRequest(requests, physical, firstLogical, suffixBits, nil)
tbc.finishCollectedRequests(physical, firstLogical, suffixBits, nil)
return nil
}

@@ -843,11 +855,3 @@ func (c *tsoClient) compareAndSwapTS(
lastTSOInfo.physical = curTSOInfo.physical
lastTSOInfo.logical = curTSOInfo.logical
}

func (c *tsoClient) finishRequest(requests []*tsoRequest, physical, firstLogical int64, suffixBits uint32, err error) {
for i := 0; i < len(requests); i++ {
requests[i].physical, requests[i].logical = physical, tsoutil.AddLogical(firstLogical, int64(i), suffixBits)
defer trace.StartRegion(requests[i].requestCtx, "pdclient.tsoReqDequeue").End()
requests[i].done <- err
}
}
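
Taken together, the dispatcher changes above enforce one invariant: once `fetchPendingRequests` has collected a batch, every exit path (fetch error, canceled context, stream-creation timeout, `processRequests` success or failure) must call `finishCollectedRequests`, since the old `finishRequest` helper is removed. The sketch below captures that shape with hypothetical `batch`/`dispatchOnce` names rather than the real dispatcher types.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// batch is a simplified stand-in for tsoBatchController's collected requests.
type batch struct{ done []chan error }

// finish delivers err (possibly nil) to every collected request, then clears
// the slice so the same batch cannot be finished twice.
func (b *batch) finish(err error) {
	for _, ch := range b.done {
		ch <- err
	}
	b.done = b.done[:0]
}

// dispatchOnce shows the invariant: whichever way this returns, finish runs.
func dispatchOnce(ctx context.Context, b *batch, process func() error) {
	select {
	case <-ctx.Done():
		b.finish(ctx.Err()) // canceled context: fail the batch instead of dropping it
		return
	default:
	}
	if err := process(); err != nil {
		b.finish(err) // stream/processing error: fail the batch
		return
	}
	b.finish(nil) // success: complete the batch
}

func main() {
	req := make(chan error, 1)
	b := &batch{done: []chan error{req}}
	dispatchOnce(context.Background(), b, func() error { return errors.New("stream error") })
	fmt.Println(<-req) // the caller is unblocked with the error instead of hanging
}
```
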
8 changes: 3 additions & 5 deletions client/tso_stream.go
@@ -106,7 +106,7 @@ type tsoStream interface {
// processRequests processes TSO requests in streaming mode to get timestamps
processRequests(
clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string,
requests []*tsoRequest, batchStartTime time.Time,
count int64, batchStartTime time.Time,
) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error)
}

@@ -120,10 +120,9 @@ func (s *pdTSOStream) getServerURL() string {
}

func (s *pdTSOStream) processRequests(
clusterID uint64, _, _ uint32, dcLocation string, requests []*tsoRequest, batchStartTime time.Time,
clusterID uint64, _, _ uint32, dcLocation string, count int64, batchStartTime time.Time,
) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) {
start := time.Now()
count := int64(len(requests))
req := &pdpb.TsoRequest{
Header: &pdpb.RequestHeader{
ClusterId: clusterID,
@@ -175,10 +174,9 @@ func (s *tsoTSOStream) getServerURL() string {

func (s *tsoTSOStream) processRequests(
clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string,
requests []*tsoRequest, batchStartTime time.Time,
count int64, batchStartTime time.Time,
) (respKeyspaceGroupID uint32, physical, logical int64, suffixBits uint32, err error) {
start := time.Now()
count := int64(len(requests))
req := &tsopb.TsoRequest{
Header: &tsopb.RequestHeader{
ClusterId: clusterID,
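
The stream `processRequests` methods above now take a plain `count int64` instead of the `[]*tsoRequest` slice: the stream layer only ever needed the number of timestamps to ask for, and dropping the slice keeps request completion solely in the batch controller's hands. A tiny sketch of the narrowed dependency, using hypothetical interface names rather than the real `tsoStream` definitions:

```go
package main

import "fmt"

// Before: the stream layer saw the request objects themselves.
type wideStream interface {
	process(requests []*struct{}) error
}

// After: the stream layer only needs a count; the batch controller alone owns
// the requests and is responsible for finishing them.
type narrowStream interface {
	process(count int64) error
}

type fakeStream struct{}

func (fakeStream) process(count int64) error {
	fmt.Printf("requesting %d timestamps in a single RPC\n", count)
	return nil
}

func main() {
	var s narrowStream = fakeStream{}
	_ = s.process(3)
}
```
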
55 changes: 52 additions & 3 deletions tests/integrations/tso/client_test.go
@@ -21,6 +21,7 @@ import (
"math/rand"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

@@ -66,6 +67,10 @@ type tsoClientTestSuite struct {
clients []pd.Client
}

func (suite *tsoClientTestSuite) getBackendEndpoints() []string {
return strings.Split(suite.backendEndpoints, ",")
}

func TestLegacyTSOClient(t *testing.T) {
suite.Run(t, &tsoClientTestSuite{
legacy: true,
@@ -98,7 +103,7 @@ func (suite *tsoClientTestSuite) SetupSuite() {
suite.keyspaceIDs = make([]uint32, 0)

if suite.legacy {
client, err := pd.NewClientWithContext(suite.ctx, strings.Split(suite.backendEndpoints, ","), pd.SecurityOption{}, pd.WithForwardingOption(true))
client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true))
re.NoError(err)
innerClient, ok := client.(interface{ GetServiceDiscovery() pd.ServiceDiscovery })
re.True(ok)
@@ -173,7 +178,7 @@ func (suite *tsoClientTestSuite) waitForAllKeyspaceGroupsInServing(re *require.A

// Create clients and make sure they all have discovered the tso service.
suite.clients = mcs.WaitForMultiKeyspacesTSOAvailable(
suite.ctx, re, suite.keyspaceIDs, strings.Split(suite.backendEndpoints, ","))
suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints())
re.Equal(len(suite.keyspaceIDs), len(suite.clients))
}

@@ -254,7 +259,7 @@ func (suite *tsoClientTestSuite) TestDiscoverTSOServiceWithLegacyPath() {
ctx, cancel := context.WithCancel(suite.ctx)
defer cancel()
client := mcs.SetupClientWithKeyspaceID(
ctx, re, keyspaceID, strings.Split(suite.backendEndpoints, ","))
ctx, re, keyspaceID, suite.getBackendEndpoints())
defer client.Close()
var lastTS uint64
for j := 0; j < tsoRequestRound; j++ {
@@ -420,6 +425,50 @@ func (suite *tsoClientTestSuite) TestRandomShutdown() {
re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/tso/fastUpdatePhysicalInterval"))
}

func (suite *tsoClientTestSuite) TestGetTSWhileResettingTSOClient() {
re := suite.Require()
var (
clients []pd.Client
stopSignal atomic.Bool
wg sync.WaitGroup
)
// Create independent clients to prevent interfering with other tests.
if suite.legacy {
client, err := pd.NewClientWithContext(suite.ctx, suite.getBackendEndpoints(), pd.SecurityOption{}, pd.WithForwardingOption(true))
re.NoError(err)
clients = []pd.Client{client}
} else {
clients = mcs.WaitForMultiKeyspacesTSOAvailable(suite.ctx, re, suite.keyspaceIDs, suite.getBackendEndpoints())
}
wg.Add(tsoRequestConcurrencyNumber * len(clients))
for i := 0; i < tsoRequestConcurrencyNumber; i++ {
for _, client := range clients {
go func(client pd.Client) {
defer wg.Done()
var lastTS uint64
for !stopSignal.Load() {
physical, logical, err := client.GetTS(suite.ctx)
if err != nil {
re.ErrorContains(err, context.Canceled.Error())
} else {
ts := tsoutil.ComposeTS(physical, logical)
re.Less(lastTS, ts)
lastTS = ts
}
}
}(client)
}
}
// Reset the TSO clients while requesting TSO concurrently.
for i := 0; i < tsoRequestConcurrencyNumber; i++ {
for _, client := range clients {
client.(interface{ ResetTSOClient() }).ResetTSOClient()
}
}
stopSignal.Store(true)
wg.Wait()
}
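
The test above hammers `GetTS` from several goroutines while `ResetTSOClient` runs concurrently, and checks monotonicity with `tsoutil.ComposeTS`. A small sketch of that comparison, assuming `ComposeTS` packs the physical part into the high bits above an 18-bit logical counter (the `composeTS` helper below is a local stand-in, not the library function):

```go
package main

import "fmt"

// composeTS mirrors what tsoutil.ComposeTS is assumed to do: pack the physical
// part (milliseconds) into the high bits and the logical counter into the low
// 18 bits, so packed timestamps order the same way as (physical, logical) pairs.
func composeTS(physical, logical int64) uint64 {
	return uint64((physical << 18) + logical)
}

func main() {
	prev := composeTS(1_700_000_000_000, 5)
	next := composeTS(1_700_000_000_000, 6) // same physical part, larger logical part
	fmt.Println(next > prev)                // true, which is what re.Less(lastTS, ts) asserts
}
```
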

// When we upgrade the PD cluster, there may be a period of time when the old and new PDs are running at the same time.
func TestMixedTSODeployment(t *testing.T) {
re := require.New(t)
