From db88a43f7783e8768730f8ab0ef6056e9d2eff04 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 17 Apr 2024 22:04:37 +0800 Subject: [PATCH] client/tso: use the TSO request pool at the tsoClient level to avoid data race (#8077) close tikv/pd#8055, ref tikv/pd#8076 Use the TSO request pool at the `tsoClient` level to avoid data race. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/client.go | 33 ++++++-------- client/tso_client.go | 61 ++++++++++++------------- client/tso_dispatcher.go | 32 -------------- client/tso_request.go | 96 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 83 deletions(-) create mode 100644 client/tso_request.go diff --git a/client/client.go b/client/client.go index 1852b77e4c6..eaebef7e10c 100644 --- a/client/client.go +++ b/client/client.go @@ -798,23 +798,7 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur defer span.Finish() } - req := c.getTSORequest(ctx, dcLocation) - if err := c.dispatchTSORequestWithRetry(req); err != nil { - req.tryDone(err) - } - return req -} - -func (c *client) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest { - req := tsoReqPool.Get().(*tsoRequest) - // Set needed fields in the request before using it. - req.start = time.Now() - req.clientCtx = c.ctx - req.requestCtx = ctx - req.physical = 0 - req.logical = 0 - req.dcLocation = dcLocation - return req + return c.dispatchTSORequestWithRetry(ctx, dcLocation) } const ( @@ -822,10 +806,11 @@ const ( dispatchRetryCount = 2 ) -func (c *client) dispatchTSORequestWithRetry(req *tsoRequest) error { +func (c *client) dispatchTSORequestWithRetry(ctx context.Context, dcLocation string) TSFuture { var ( retryable bool err error + req *tsoRequest ) for i := 0; i < dispatchRetryCount; i++ { // Do not delay for the first time. @@ -838,12 +823,22 @@ func (c *client) dispatchTSORequestWithRetry(req *tsoRequest) error { err = errs.ErrClientGetTSO.FastGenByArgs("tso client is nil") continue } + // Get a new request from the pool if it's nil or not from the current pool. + if req == nil || req.pool != tsoClient.tsoReqPool { + req = tsoClient.getTSORequest(ctx, dcLocation) + } retryable, err = tsoClient.dispatchRequest(req) if !retryable { break } } - return err + if err != nil { + if req == nil { + return newTSORequestFastFail(err) + } + req.tryDone(err) + } + return req } func (c *client) GetTS(ctx context.Context) (physical int64, logical int64, err error) { diff --git a/client/tso_client.go b/client/tso_client.go index 5f8b12df36f..8185b99d1d0 100644 --- a/client/tso_client.go +++ b/client/tso_client.go @@ -43,33 +43,6 @@ type TSOClient interface { GetMinTS(ctx context.Context) (int64, int64, error) } -type tsoRequest struct { - start time.Time - clientCtx context.Context - requestCtx context.Context - done chan error - physical int64 - logical int64 - dcLocation string -} - -var tsoReqPool = sync.Pool{ - New: func() any { - return &tsoRequest{ - done: make(chan error, 1), - physical: 0, - logical: 0, - } - }, -} - -func (req *tsoRequest) tryDone(err error) { - select { - case req.done <- err: - default: - } -} - type tsoClient struct { ctx context.Context cancel context.CancelFunc @@ -84,6 +57,8 @@ type tsoClient struct { // tso allocator leader is switched. tsoAllocServingURLSwitchedCallback []func() + // tsoReqPool is the pool to recycle `*tsoRequest`. + tsoReqPool *sync.Pool // tsoDispatcher is used to dispatch different TSO requests to // the corresponding dc-location TSO channel. tsoDispatcher sync.Map // Same as map[string]*tsoDispatcher @@ -104,11 +79,20 @@ func newTSOClient( ) *tsoClient { ctx, cancel := context.WithCancel(ctx) c := &tsoClient{ - ctx: ctx, - cancel: cancel, - option: option, - svcDiscovery: svcDiscovery, - tsoStreamBuilderFactory: factory, + ctx: ctx, + cancel: cancel, + option: option, + svcDiscovery: svcDiscovery, + tsoStreamBuilderFactory: factory, + tsoReqPool: &sync.Pool{ + New: func() any { + return &tsoRequest{ + done: make(chan error, 1), + physical: 0, + logical: 0, + } + }, + }, checkTSDeadlineCh: make(chan struct{}), checkTSODispatcherCh: make(chan struct{}, 1), updateTSOConnectionCtxsCh: make(chan struct{}, 1), @@ -155,6 +139,19 @@ func (c *tsoClient) Close() { log.Info("tso client is closed") } +func (c *tsoClient) getTSORequest(ctx context.Context, dcLocation string) *tsoRequest { + req := c.tsoReqPool.Get().(*tsoRequest) + // Set needed fields in the request before using it. + req.start = time.Now() + req.pool = c.tsoReqPool + req.requestCtx = ctx + req.clientCtx = c.ctx + req.physical = 0 + req.logical = 0 + req.dcLocation = dcLocation + return req +} + // GetTSOAllocators returns {dc-location -> TSO allocator leader URL} connection map func (c *tsoClient) GetTSOAllocators() *sync.Map { return &c.tsoAllocators diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index ad3aa1c5d74..d02fdd52af8 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -115,38 +115,6 @@ func (c *tsoClient) dispatchRequest(request *tsoRequest) (bool, error) { return false, nil } -// TSFuture is a future which promises to return a TSO. -type TSFuture interface { - // Wait gets the physical and logical time, it would block caller if data is not available yet. - Wait() (int64, int64, error) -} - -func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { - // If tso command duration is observed very high, the reason could be it - // takes too long for Wait() be called. - start := time.Now() - cmdDurationTSOAsyncWait.Observe(start.Sub(req.start).Seconds()) - select { - case err = <-req.done: - defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End() - err = errors.WithStack(err) - defer tsoReqPool.Put(req) - if err != nil { - cmdFailDurationTSO.Observe(time.Since(req.start).Seconds()) - return 0, 0, err - } - physical, logical = req.physical, req.logical - now := time.Now() - cmdDurationWait.Observe(now.Sub(start).Seconds()) - cmdDurationTSO.Observe(now.Sub(req.start).Seconds()) - return - case <-req.requestCtx.Done(): - return 0, 0, errors.WithStack(req.requestCtx.Err()) - case <-req.clientCtx.Done(): - return 0, 0, errors.WithStack(req.clientCtx.Err()) - } -} - func (c *tsoClient) updateTSODispatcher() { // Set up the new TSO dispatcher and batch controller. c.GetTSOAllocators().Range(func(dcLocationKey, _ any) bool { diff --git a/client/tso_request.go b/client/tso_request.go new file mode 100644 index 00000000000..f30ceb5268a --- /dev/null +++ b/client/tso_request.go @@ -0,0 +1,96 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "runtime/trace" + "sync" + "time" + + "github.com/pingcap/errors" +) + +// TSFuture is a future which promises to return a TSO. +type TSFuture interface { + // Wait gets the physical and logical time, it would block caller if data is not available yet. + Wait() (int64, int64, error) +} + +var ( + _ TSFuture = (*tsoRequest)(nil) + _ TSFuture = (*tsoRequestFastFail)(nil) +) + +type tsoRequest struct { + requestCtx context.Context + clientCtx context.Context + done chan error + physical int64 + logical int64 + dcLocation string + + // Runtime fields. + start time.Time + pool *sync.Pool +} + +// tryDone tries to send the result to the channel, it will not block. +func (req *tsoRequest) tryDone(err error) { + select { + case req.done <- err: + default: + } +} + +// Wait will block until the TSO result is ready. +func (req *tsoRequest) Wait() (physical int64, logical int64, err error) { + // If tso command duration is observed very high, the reason could be it + // takes too long for Wait() be called. + start := time.Now() + cmdDurationTSOAsyncWait.Observe(start.Sub(req.start).Seconds()) + select { + case err = <-req.done: + defer trace.StartRegion(req.requestCtx, "pdclient.tsoReqDone").End() + defer req.pool.Put(req) + err = errors.WithStack(err) + if err != nil { + cmdFailDurationTSO.Observe(time.Since(req.start).Seconds()) + return 0, 0, err + } + physical, logical = req.physical, req.logical + now := time.Now() + cmdDurationWait.Observe(now.Sub(start).Seconds()) + cmdDurationTSO.Observe(now.Sub(req.start).Seconds()) + return + case <-req.requestCtx.Done(): + return 0, 0, errors.WithStack(req.requestCtx.Err()) + case <-req.clientCtx.Done(): + return 0, 0, errors.WithStack(req.clientCtx.Err()) + } +} + +type tsoRequestFastFail struct { + err error +} + +func newTSORequestFastFail(err error) *tsoRequestFastFail { + return &tsoRequestFastFail{err} +} + +// Wait returns the error directly. +func (req *tsoRequestFastFail) Wait() (physical int64, logical int64, err error) { + return 0, 0, req.err +}