From 8cd72333f15ffbf391f836f618baa4685ef64a65 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Fri, 27 Dec 2024 16:06:15 +0800 Subject: [PATCH] client: introduce the connection ctx manager (#8940) ref tikv/pd#8690 Previously, we used a `sync.Map` as a medium to propagate connection ctx updates between the dispatcher and TSO client, which introduced a lot of redundant parameter passing and made the logic less intuitive. This PR implements the same functionality using a common connection ctx manager to simplify and reuse related code. Signed-off-by: JmPotato --- client/clients/tso/client.go | 143 ++++++++------ client/clients/tso/dispatcher.go | 178 +++++------------- client/clients/tso/dispatcher_test.go | 19 +- client/pkg/connectionctx/manager.go | 143 ++++++++++++++ client/pkg/connectionctx/manager_test.go | 83 ++++++++ .../servicediscovery/pd_service_discovery.go | 9 +- 6 files changed, 371 insertions(+), 204 deletions(-) create mode 100644 client/pkg/connectionctx/manager.go create mode 100644 client/pkg/connectionctx/manager_test.go diff --git a/client/clients/tso/client.go b/client/clients/tso/client.go index c26dd25f2ad..c6caa8b985f 100644 --- a/client/clients/tso/client.go +++ b/client/clients/tso/client.go @@ -36,6 +36,7 @@ import ( "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" + cctx "github.com/tikv/pd/client/pkg/connectionctx" "github.com/tikv/pd/client/pkg/utils/grpcutil" "github.com/tikv/pd/client/pkg/utils/tlsutil" sd "github.com/tikv/pd/client/servicediscovery" @@ -80,7 +81,9 @@ type Cli struct { svcDiscovery sd.ServiceDiscovery tsoStreamBuilderFactory // leaderURL is the URL of the TSO leader. - leaderURL atomic.Value + leaderURL atomic.Value + conCtxMgr *cctx.Manager[*tsoStream] + updateConCtxsCh chan struct{} // tsoReqPool is the pool to recycle `*tsoRequest`. tsoReqPool *sync.Pool @@ -100,6 +103,8 @@ func NewClient( option: option, svcDiscovery: svcDiscovery, tsoStreamBuilderFactory: factory, + conCtxMgr: cctx.NewManager[*tsoStream](), + updateConCtxsCh: make(chan struct{}, 1), tsoReqPool: &sync.Pool{ New: func() any { return &Request{ @@ -122,6 +127,8 @@ func (c *Cli) getOption() *opt.Option { return c.option } func (c *Cli) getServiceDiscovery() sd.ServiceDiscovery { return c.svcDiscovery } +func (c *Cli) getConnectionCtxMgr() *cctx.Manager[*tsoStream] { return c.conCtxMgr } + func (c *Cli) getDispatcher() *tsoDispatcher { return c.dispatcher.Load() } @@ -133,6 +140,8 @@ func (c *Cli) GetRequestPool() *sync.Pool { // Setup initializes the TSO client. func (c *Cli) Setup() { + // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. + go c.connectionCtxsUpdater() if err := c.svcDiscovery.CheckMemberChanged(); err != nil { log.Warn("[tso] failed to check member changed", errs.ZapError(err)) } @@ -154,9 +163,12 @@ func (c *Cli) Close() { log.Info("[tso] tso client is closed") } -// scheduleUpdateTSOConnectionCtxs update the TSO connection contexts. +// scheduleUpdateTSOConnectionCtxs schedules the update of the TSO connection contexts. func (c *Cli) scheduleUpdateTSOConnectionCtxs() { - c.getDispatcher().scheduleUpdateConnectionCtxs() + select { + case c.updateConCtxsCh <- struct{}{}: + default: + } } // GetTSORequest gets a TSO request from the pool. @@ -231,25 +243,66 @@ func (c *Cli) backupClientConn() (*grpc.ClientConn, string) { return nil, "" } -// tsoConnectionContext is used to store the context of a TSO stream connection. -type tsoConnectionContext struct { - ctx context.Context - cancel context.CancelFunc - // Current URL of the stream connection. - streamURL string - // Current stream to send gRPC requests. - stream *tsoStream +// connectionCtxsUpdater updates the `connectionCtxs` regularly. +func (c *Cli) connectionCtxsUpdater() { + log.Info("[tso] start tso connection contexts updater") + + var updateTicker = &time.Ticker{} + setNewUpdateTicker := func(interval time.Duration) { + if updateTicker.C != nil { + updateTicker.Stop() + } + if interval == 0 { + updateTicker = &time.Ticker{} + } else { + updateTicker = time.NewTicker(interval) + } + } + // If the TSO Follower Proxy is enabled, set the update interval to the member update interval. + if c.option.GetEnableTSOFollowerProxy() { + setNewUpdateTicker(sd.MemberUpdateInterval) + } + // Set to nil before returning to ensure that the existing ticker can be GC. + defer setNewUpdateTicker(0) + + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + for { + c.updateConnectionCtxs(ctx) + select { + case <-ctx.Done(): + log.Info("[tso] exit tso connection contexts updater") + return + case <-c.option.EnableTSOFollowerProxyCh: + enableTSOFollowerProxy := c.option.GetEnableTSOFollowerProxy() + log.Info("[tso] tso follower proxy status changed", + zap.Bool("enable", enableTSOFollowerProxy)) + if enableTSOFollowerProxy && updateTicker.C == nil { + // Because the TSO Follower Proxy is enabled, + // the periodic check needs to be performed. + setNewUpdateTicker(sd.MemberUpdateInterval) + } else if !enableTSOFollowerProxy && updateTicker.C != nil { + // Because the TSO Follower Proxy is disabled, + // the periodic check needs to be turned off. + setNewUpdateTicker(0) + } + case <-updateTicker.C: + // Triggered periodically when the TSO Follower Proxy is enabled. + case <-c.updateConCtxsCh: + // Triggered by the leader/follower change. + } + } } // updateConnectionCtxs will choose the proper way to update the connections. // It will return a bool to indicate whether the update is successful. -func (c *Cli) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { +func (c *Cli) updateConnectionCtxs(ctx context.Context) bool { // Normal connection creating, it will be affected by the `enableForwarding`. createTSOConnection := c.tryConnectToTSO if c.option.GetEnableTSOFollowerProxy() { createTSOConnection = c.tryConnectToTSOWithProxy } - if err := createTSOConnection(ctx, connectionCtxs); err != nil { + if err := createTSOConnection(ctx); err != nil { log.Error("[tso] update connection contexts failed", errs.ZapError(err)) return false } @@ -260,30 +313,13 @@ func (c *Cli) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map // and enableForwarding is true, it will create a new connection to a follower to do the forwarding, // while a new daemon will be created also to switch back to a normal leader connection ASAP the // connection comes back to normal. -func (c *Cli) tryConnectToTSO( - ctx context.Context, - connectionCtxs *sync.Map, -) error { +func (c *Cli) tryConnectToTSO(ctx context.Context) error { var ( - networkErrNum uint64 - err error - stream *tsoStream - url string - cc *grpc.ClientConn - updateAndClear = func(newURL string, connectionCtx *tsoConnectionContext) { - // Only store the `connectionCtx` if it does not exist before. - if connectionCtx != nil { - connectionCtxs.LoadOrStore(newURL, connectionCtx) - } - // Remove all other `connectionCtx`s. - connectionCtxs.Range(func(url, cc any) bool { - if url.(string) != newURL { - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Delete(url) - } - return true - }) - } + networkErrNum uint64 + err error + stream *tsoStream + url string + cc *grpc.ClientConn ) ticker := time.NewTicker(constants.RetryInterval) @@ -292,9 +328,9 @@ func (c *Cli) tryConnectToTSO( for range constants.MaxRetryTimes { c.svcDiscovery.ScheduleCheckMemberChanged() cc, url = c.getTSOLeaderClientConn() - if _, ok := connectionCtxs.Load(url); ok { + if c.conCtxMgr.Exist(url) { // Just trigger the clean up of the stale connection contexts. - updateAndClear(url, nil) + c.conCtxMgr.CleanAllAndStore(ctx, url) return nil } if cc != nil { @@ -305,7 +341,7 @@ func (c *Cli) tryConnectToTSO( err = status.New(codes.Unavailable, "unavailable").Err() }) if stream != nil && err == nil { - updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, url, stream) return nil } @@ -348,9 +384,9 @@ func (c *Cli) tryConnectToTSO( forwardedHostTrim := tlsutil.TrimHTTPPrefix(forwardedHost) addr := tlsutil.TrimHTTPPrefix(backupURL) // the goroutine is used to check the network and change back to the original stream - go c.checkLeader(ctx, cancel, forwardedHostTrim, addr, url, updateAndClear) + go c.checkLeader(ctx, cancel, forwardedHostTrim, addr, url) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(1) - updateAndClear(backupURL, &tsoConnectionContext{cctx, cancel, backupURL, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, backupURL, stream) return nil } cancel() @@ -363,7 +399,6 @@ func (c *Cli) checkLeader( ctx context.Context, forwardCancel context.CancelFunc, forwardedHostTrim, addr, url string, - updateAndClear func(newAddr string, connectionCtx *tsoConnectionContext), ) { defer func() { // cancel the forward stream @@ -396,7 +431,7 @@ func (c *Cli) checkLeader( stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.Timeout) if err == nil && stream != nil { log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("url", url)) - updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, url, stream) return } } @@ -413,10 +448,7 @@ func (c *Cli) checkLeader( // tryConnectToTSOWithProxy will create multiple streams to all the service endpoints to work as // a TSO proxy to reduce the pressure of the main serving service endpoint. -func (c *Cli) tryConnectToTSOWithProxy( - ctx context.Context, - connectionCtxs *sync.Map, -) error { +func (c *Cli) tryConnectToTSOWithProxy(ctx context.Context) error { tsoStreamBuilders := c.getAllTSOStreamBuilders() leaderAddr := c.svcDiscovery.GetServingURL() forwardedHost := c.getLeaderURL() @@ -424,20 +456,17 @@ func (c *Cli) tryConnectToTSOWithProxy( return errors.Errorf("cannot find the tso leader") } // GC the stale one. - connectionCtxs.Range(func(addr, cc any) bool { - addrStr := addr.(string) - if _, ok := tsoStreamBuilders[addrStr]; !ok { + c.conCtxMgr.GC(func(addr string) bool { + _, ok := tsoStreamBuilders[addr] + if !ok { log.Info("[tso] remove the stale tso stream", - zap.String("addr", addrStr)) - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Delete(addr) + zap.String("addr", addr)) } - return true + return !ok }) // Update the missing one. for addr, tsoStreamBuilder := range tsoStreamBuilders { - _, ok := connectionCtxs.Load(addr) - if ok { + if c.conCtxMgr.Exist(addr) { continue } log.Info("[tso] try to create tso stream", zap.String("addr", addr)) @@ -456,7 +485,7 @@ func (c *Cli) tryConnectToTSOWithProxy( addrTrim := tlsutil.TrimHTTPPrefix(addr) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) } - connectionCtxs.Store(addr, &tsoConnectionContext{cctx, cancel, addr, stream}) + c.conCtxMgr.Store(ctx, addr, stream) continue } log.Error("[tso] create the tso stream failed", diff --git a/client/clients/tso/dispatcher.go b/client/clients/tso/dispatcher.go index 58722088886..c05ab27d755 100644 --- a/client/clients/tso/dispatcher.go +++ b/client/clients/tso/dispatcher.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "math" - "math/rand" "runtime/trace" "sync" "sync/atomic" @@ -36,6 +35,7 @@ import ( "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/batch" + cctx "github.com/tikv/pd/client/pkg/connectionctx" "github.com/tikv/pd/client/pkg/retry" "github.com/tikv/pd/client/pkg/utils/timerutil" "github.com/tikv/pd/client/pkg/utils/tsoutil" @@ -76,7 +76,8 @@ type tsoInfo struct { type tsoServiceProvider interface { getOption() *opt.Option getServiceDiscovery() sd.ServiceDiscovery - updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool + getConnectionCtxMgr() *cctx.Manager[*tsoStream] + updateConnectionCtxs(ctx context.Context) bool } const dispatcherCheckRPCConcurrencyInterval = time.Second * 5 @@ -85,12 +86,10 @@ type tsoDispatcher struct { ctx context.Context cancel context.CancelFunc - provider tsoServiceProvider - // URL -> *connectionContext - connectionCtxs *sync.Map - tsoRequestCh chan *Request - tsDeadlineCh chan *deadline - latestTSOInfo atomic.Pointer[tsoInfo] + provider tsoServiceProvider + tsoRequestCh chan *Request + tsDeadlineCh chan *deadline + latestTSOInfo atomic.Pointer[tsoInfo] // For reusing `*batchController` objects batchBufferPool *sync.Pool @@ -102,8 +101,6 @@ type tsoDispatcher struct { lastCheckConcurrencyTime time.Time tokenCount int rpcConcurrency int - - updateConnectionCtxsCh chan struct{} } func newTSODispatcher( @@ -122,12 +119,11 @@ func newTSODispatcher( tokenCh := make(chan struct{}, tokenChCapacity) td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - provider: provider, - connectionCtxs: &sync.Map{}, - tsoRequestCh: tsoRequestCh, - tsDeadlineCh: make(chan *deadline, tokenChCapacity), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + provider: provider, + tsoRequestCh: tsoRequestCh, + tsDeadlineCh: make(chan *deadline, tokenChCapacity), batchBufferPool: &sync.Pool{ New: func() any { return batch.NewController[*Request]( @@ -137,8 +133,7 @@ func newTSODispatcher( ) }, }, - tokenCh: tokenCh, - updateConnectionCtxsCh: make(chan struct{}, 1), + tokenCh: tokenCh, } go td.watchTSDeadline() return td @@ -168,13 +163,6 @@ func (td *tsoDispatcher) watchTSDeadline() { } } -func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { - select { - case td.updateConnectionCtxsCh <- struct{}{}: - default: - } -} - func (td *tsoDispatcher) revokePendingRequests(err error) { for range len(td.tsoRequestCh) { req := <-td.tsoRequestCh @@ -196,9 +184,9 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { var ( ctx = td.ctx provider = td.provider - svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() - connectionCtxs = td.connectionCtxs + svcDiscovery = provider.getServiceDiscovery() + conCtxMgr = provider.getConnectionCtxMgr() tsoBatchController *batch.Controller[*Request] ) @@ -207,10 +195,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { defer func() { log.Info("[tso] exit tso dispatcher") // Cancel all connections. - connectionCtxs.Range(func(_, cc any) bool { - cc.(*tsoConnectionContext).cancel() - return true - }) + conCtxMgr.ReleaseAll() if tsoBatchController != nil && tsoBatchController.GetCollectedRequestCount() != 0 { // If you encounter this failure, please check the stack in the logs to see if it's a panic. log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) @@ -219,8 +204,6 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { td.revokePendingRequests(tsoErr) wg.Done() }() - // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. - go td.connectionCtxsUpdater() var ( err error @@ -291,14 +274,14 @@ tsoBatchLoop: // Choose a stream to send the TSO gRPC request. streamChoosingLoop: for { - connectionCtx := chooseStream(connectionCtxs) + connectionCtx := conCtxMgr.GetConnectionCtx() if connectionCtx != nil { - streamCtx, cancel, streamURL, stream = connectionCtx.ctx, connectionCtx.cancel, connectionCtx.streamURL, connectionCtx.stream + streamCtx, cancel, streamURL, stream = connectionCtx.Ctx, connectionCtx.Cancel, connectionCtx.StreamURL, connectionCtx.Stream } // Check stream and retry if necessary. if stream == nil { log.Info("[tso] tso stream is not ready") - if provider.updateConnectionCtxs(ctx, connectionCtxs) { + if provider.updateConnectionCtxs(ctx) { continue streamChoosingLoop } timer := time.NewTimer(constants.RetryInterval) @@ -325,8 +308,7 @@ tsoBatchLoop: case <-streamCtx.Done(): log.Info("[tso] tso stream is canceled", zap.String("stream-url", streamURL)) // Set `stream` to nil and remove this stream from the `connectionCtxs` due to being canceled. - connectionCtxs.Delete(streamURL) - cancel() + conCtxMgr.Release(streamURL) stream = nil continue default: @@ -334,7 +316,7 @@ tsoBatchLoop: // Check if any error has occurred on this stream when receiving asynchronously. if err = stream.GetRecvError(); err != nil { - exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + exit := !td.handleProcessRequestError(ctx, bo, conCtxMgr, streamURL, err) stream = nil if exit { td.cancelCollectedRequests(tsoBatchController, invalidStreamID, errors.WithStack(ctx.Err())) @@ -419,7 +401,7 @@ tsoBatchLoop: // reused in the next loop safely. tsoBatchController = nil } else { - exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + exit := !td.handleProcessRequestError(ctx, bo, conCtxMgr, streamURL, err) stream = nil if exit { return @@ -430,110 +412,44 @@ tsoBatchLoop: // handleProcessRequestError handles errors occurs when trying to process a TSO RPC request for the dispatcher loop. // Returns true if the dispatcher loop is ok to continue. Otherwise, the dispatcher loop should be exited. -func (td *tsoDispatcher) handleProcessRequestError(ctx context.Context, bo *retry.Backoffer, streamURL string, streamCancelFunc context.CancelFunc, err error) bool { +func (td *tsoDispatcher) handleProcessRequestError( + ctx context.Context, + bo *retry.Backoffer, + conCtxMgr *cctx.Manager[*tsoStream], + streamURL string, + err error, +) bool { + log.Error("[tso] getTS error after processing requests", + zap.String("stream-url", streamURL), + zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) + select { case <-ctx.Done(): return false default: } + // Release this stream from the manager due to error. + conCtxMgr.Release(streamURL) + // Update the member list to ensure the latest topology is used before the next batch. svcDiscovery := td.provider.getServiceDiscovery() - - svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error after processing requests", - zap.String("stream-url", streamURL), - zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) - // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. - td.connectionCtxs.Delete(streamURL) - streamCancelFunc() - // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. if errs.IsLeaderChange(err) { + // If the leader changed, we better call `CheckMemberChanged` blockingly to + // ensure the next round of TSO requests can be sent to the new leader. if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { - select { - case <-ctx.Done(): - return false - default: - } + log.Error("[tso] check member changed error after the leader changed", zap.Error(err)) } - // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateConnectionCtxs - // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateConnectionCtxs here. - // So we should only call it when the leader changes. - td.provider.updateConnectionCtxs(ctx, td.connectionCtxs) + } else { + // For other errors, we can just schedule a member change check asynchronously. + svcDiscovery.ScheduleCheckMemberChanged() } - return true -} - -// updateConnectionCtxs updates the `connectionCtxs` regularly. -func (td *tsoDispatcher) connectionCtxsUpdater() { - var ( - ctx = td.ctx - connectionCtxs = td.connectionCtxs - provider = td.provider - option = td.provider.getOption() - updateTicker = &time.Ticker{} - ) - - log.Info("[tso] start tso connection contexts updater") - setNewUpdateTicker := func(interval time.Duration) { - if updateTicker.C != nil { - updateTicker.Stop() - } - if interval == 0 { - updateTicker = &time.Ticker{} - } else { - updateTicker = time.NewTicker(interval) - } - } - // If the TSO Follower Proxy is enabled, set the update interval to the member update interval. - if option.GetEnableTSOFollowerProxy() { - setNewUpdateTicker(sd.MemberUpdateInterval) - } - // Set to nil before returning to ensure that the existing ticker can be GC. - defer setNewUpdateTicker(0) - - for { - provider.updateConnectionCtxs(ctx, connectionCtxs) - select { - case <-ctx.Done(): - log.Info("[tso] exit tso connection contexts updater") - return - case <-option.EnableTSOFollowerProxyCh: - enableTSOFollowerProxy := option.GetEnableTSOFollowerProxy() - log.Info("[tso] tso follower proxy status changed", - zap.Bool("enable", enableTSOFollowerProxy)) - if enableTSOFollowerProxy && updateTicker.C == nil { - // Because the TSO Follower Proxy is enabled, - // the periodic check needs to be performed. - setNewUpdateTicker(sd.MemberUpdateInterval) - } else if !enableTSOFollowerProxy && updateTicker.C != nil { - // Because the TSO Follower Proxy is disabled, - // the periodic check needs to be turned off. - setNewUpdateTicker(0) - } - case <-updateTicker.C: - // Triggered periodically when the TSO Follower Proxy is enabled. - case <-td.updateConnectionCtxsCh: - // Triggered by the leader/follower change. - } - } -} - -// chooseStream uses the reservoir sampling algorithm to randomly choose a connection. -// connectionCtxs will only have only one stream to choose when the TSO Follower Proxy is off. -func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext) { - idx := 0 - connectionCtxs.Range(func(_, cc any) bool { - j := rand.Intn(idx + 1) - if j < 1 { - connectionCtx = cc.(*tsoConnectionContext) - } - idx++ + select { + case <-ctx.Done(): + return false + default: return true - }) - return connectionCtx + } } // processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests diff --git a/client/clients/tso/dispatcher_test.go b/client/clients/tso/dispatcher_test.go index cefc53f3944..7e5554c7c7b 100644 --- a/client/clients/tso/dispatcher_test.go +++ b/client/clients/tso/dispatcher_test.go @@ -30,19 +30,21 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/client/opt" + cctx "github.com/tikv/pd/client/pkg/connectionctx" sd "github.com/tikv/pd/client/servicediscovery" ) type mockTSOServiceProvider struct { option *opt.Option createStream func(ctx context.Context) *tsoStream - updateConnMu sync.Mutex + conCtxMgr *cctx.Manager[*tsoStream] } func newMockTSOServiceProvider(option *opt.Option, createStream func(ctx context.Context) *tsoStream) *mockTSOServiceProvider { return &mockTSOServiceProvider{ option: option, createStream: createStream, + conCtxMgr: cctx.NewManager[*tsoStream](), } } @@ -54,24 +56,21 @@ func (*mockTSOServiceProvider) getServiceDiscovery() sd.ServiceDiscovery { return sd.NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } -func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { - // Avoid concurrent updating in the background updating goroutine and active updating in the dispatcher loop when - // stream is missing. - m.updateConnMu.Lock() - defer m.updateConnMu.Unlock() +func (m *mockTSOServiceProvider) getConnectionCtxMgr() *cctx.Manager[*tsoStream] { + return m.conCtxMgr +} - _, ok := connectionCtxs.Load(mockStreamURL) - if ok { +func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context) bool { + if m.conCtxMgr.Exist(mockStreamURL) { return true } - ctx, cancel := context.WithCancel(ctx) var stream *tsoStream if m.createStream == nil { stream = newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, resultModeGenerated)) } else { stream = m.createStream(ctx) } - connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) + m.conCtxMgr.Store(ctx, mockStreamURL, stream) return true } diff --git a/client/pkg/connectionctx/manager.go b/client/pkg/connectionctx/manager.go new file mode 100644 index 00000000000..04c1eb13d3a --- /dev/null +++ b/client/pkg/connectionctx/manager.go @@ -0,0 +1,143 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionctx + +import ( + "context" + "sync" + + "golang.org/x/exp/rand" +) + +type connectionCtx[T any] struct { + Ctx context.Context + Cancel context.CancelFunc + // Current URL of the stream connection. + StreamURL string + // Current stream to send the gRPC requests. + Stream T +} + +// Manager is used to manage the connection contexts. +type Manager[T any] struct { + sync.RWMutex + connectionCtxs map[string]*connectionCtx[T] +} + +// NewManager is used to create a new connection context manager. +func NewManager[T any]() *Manager[T] { + return &Manager[T]{ + connectionCtxs: make(map[string]*connectionCtx[T], 3), + } +} + +// Exist is used to check if the connection context exists by the given URL. +func (c *Manager[T]) Exist(url string) bool { + c.RLock() + defer c.RUnlock() + _, ok := c.connectionCtxs[url] + return ok +} + +// Store is used to store the connection context, `overwrite` is used to force the store operation +// no matter whether the connection context exists before, which is false by default. +func (c *Manager[T]) Store(ctx context.Context, url string, stream T, overwrite ...bool) { + c.Lock() + defer c.Unlock() + overwriteFlag := false + if len(overwrite) > 0 { + overwriteFlag = overwrite[0] + } + _, ok := c.connectionCtxs[url] + if !overwriteFlag && ok { + return + } + c.storeLocked(ctx, url, stream) +} + +func (c *Manager[T]) storeLocked(ctx context.Context, url string, stream T) { + c.releaseLocked(url) + cctx, cancel := context.WithCancel(ctx) + c.connectionCtxs[url] = &connectionCtx[T]{cctx, cancel, url, stream} +} + +// CleanAllAndStore is used to store the connection context exclusively. It will release +// all other connection contexts. `stream` is optional, if it is not provided, all +// connection contexts other than the given `url` will be released. +func (c *Manager[T]) CleanAllAndStore(ctx context.Context, url string, stream ...T) { + c.Lock() + defer c.Unlock() + // Remove all other `connectionCtx`s. + c.gcLocked(func(curURL string) bool { + return curURL != url + }) + if len(stream) == 0 { + return + } + c.storeLocked(ctx, url, stream[0]) +} + +// GC is used to release all connection contexts that match the given condition. +func (c *Manager[T]) GC(condition func(url string) bool) { + c.Lock() + defer c.Unlock() + c.gcLocked(condition) +} + +func (c *Manager[T]) gcLocked(condition func(url string) bool) { + for url := range c.connectionCtxs { + if condition(url) { + c.releaseLocked(url) + } + } +} + +// ReleaseAll is used to release all connection contexts. +func (c *Manager[T]) ReleaseAll() { + c.GC(func(string) bool { return true }) +} + +// Release is used to delete a connection context from the connection context map and release the resources. +func (c *Manager[T]) Release(url string) { + c.Lock() + defer c.Unlock() + c.releaseLocked(url) +} + +func (c *Manager[T]) releaseLocked(url string) { + cc, ok := c.connectionCtxs[url] + if !ok { + return + } + cc.Cancel() + delete(c.connectionCtxs, url) +} + +// GetConnectionCtx is used to get a connection context from the connection context map. +// It uses the reservoir sampling algorithm to randomly pick one connection context. +func (c *Manager[T]) GetConnectionCtx() *connectionCtx[T] { + c.RLock() + defer c.RUnlock() + idx := 0 + var connectionCtx *connectionCtx[T] + for _, cc := range c.connectionCtxs { + j := rand.Intn(idx + 1) + if j < 1 { + connectionCtx = cc + } + idx++ + } + return connectionCtx +} diff --git a/client/pkg/connectionctx/manager_test.go b/client/pkg/connectionctx/manager_test.go new file mode 100644 index 00000000000..42504673b95 --- /dev/null +++ b/client/pkg/connectionctx/manager_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionctx + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestManager(t *testing.T) { + re := require.New(t) + ctx := context.Background() + manager := NewManager[int]() + + re.False(manager.Exist("test-url")) + manager.Store(ctx, "test-url", 1) + re.True(manager.Exist("test-url")) + + cctx := manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(1, cctx.Stream) + + manager.Store(ctx, "test-url", 2) + cctx = manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(1, cctx.Stream) + + manager.Store(ctx, "test-url", 2, true) + cctx = manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(2, cctx.Stream) + + manager.Store(ctx, "test-another-url", 3) + pickedCount := make(map[string]int) + for range 1000 { + cctx = manager.GetConnectionCtx() + pickedCount[cctx.StreamURL]++ + } + re.NotEmpty(pickedCount["test-url"]) + re.NotEmpty(pickedCount["test-another-url"]) + re.Equal(1000, pickedCount["test-url"]+pickedCount["test-another-url"]) + + manager.GC(func(url string) bool { + return url == "test-url" + }) + re.False(manager.Exist("test-url")) + re.True(manager.Exist("test-another-url")) + + manager.CleanAllAndStore(ctx, "test-url", 1) + re.True(manager.Exist("test-url")) + re.False(manager.Exist("test-another-url")) + + manager.Store(ctx, "test-another-url", 3) + manager.CleanAllAndStore(ctx, "test-unique-url", 4) + re.True(manager.Exist("test-unique-url")) + re.False(manager.Exist("test-url")) + re.False(manager.Exist("test-another-url")) + + manager.Release("test-unique-url") + re.False(manager.Exist("test-unique-url")) + + for i := range 1000 { + manager.Store(ctx, fmt.Sprintf("test-url-%d", i), i) + } + re.Len(manager.connectionCtxs, 1000) + manager.ReleaseAll() + re.Empty(manager.connectionCtxs) +} diff --git a/client/servicediscovery/pd_service_discovery.go b/client/servicediscovery/pd_service_discovery.go index 619d4196408..5530f3cfa9b 100644 --- a/client/servicediscovery/pd_service_discovery.go +++ b/client/servicediscovery/pd_service_discovery.go @@ -966,12 +966,9 @@ func (c *pdServiceDiscovery) updateURLs(members []*pdpb.Member) { return } c.urls.Store(urls) - // Update the connection contexts when member changes if TSO Follower Proxy is enabled. - if c.option.GetEnableTSOFollowerProxy() { - // Run callbacks to reflect the membership changes in the leader and followers. - for _, cb := range c.membersChangedCbs { - cb() - } + // Run callbacks to reflect the membership changes in the leader and followers. + for _, cb := range c.membersChangedCbs { + cb() } log.Info("[pd] update member urls", zap.Strings("old-urls", oldURLs), zap.Strings("new-urls", urls)) }