Skip to content

Commit

Permalink
use interceptor for circuit breaker
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <[email protected]>
  • Loading branch information
rleungx committed Dec 19, 2024
1 parent ecb31de commit 579f67b
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 119 deletions.
28 changes: 4 additions & 24 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
Expand All @@ -42,7 +40,6 @@ import (
"github.com/tikv/pd/client/metrics"
"github.com/tikv/pd/client/opt"
"github.com/tikv/pd/client/pkg/caller"
cb "github.com/tikv/pd/client/pkg/circuitbreaker"
"github.com/tikv/pd/client/pkg/utils/tlsutil"
sd "github.com/tikv/pd/client/servicediscovery"
)
Expand Down Expand Up @@ -461,12 +458,6 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error {
return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int")
}
c.inner.option.SetTSOClientRPCConcurrency(value)
case opt.RegionMetadataCircuitBreakerSettings:
applySettingsChange, ok := value.(func(config *cb.Settings))
if !ok {
return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings")
}
c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange)
default:
return errors.New("[pd] unsupported client option")
}
Expand Down Expand Up @@ -661,13 +652,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req)
failpoint.Inject("triggerCircuitBreaker", func() {
err = status.Error(codes.ResourceExhausted, "resource exhausted")
})
return region, isOverloaded(err), err
})
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down Expand Up @@ -707,10 +692,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req)
return resp, isOverloaded(err), err
})
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down Expand Up @@ -750,10 +732,8 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req)
return resp, isOverloaded(err), err
})

resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down
22 changes: 4 additions & 18 deletions client/inner_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (

"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/pdpb"
Expand All @@ -19,7 +17,6 @@ import (
"github.com/tikv/pd/client/errs"
"github.com/tikv/pd/client/metrics"
"github.com/tikv/pd/client/opt"
cb "github.com/tikv/pd/client/pkg/circuitbreaker"
sd "github.com/tikv/pd/client/servicediscovery"
)

Expand All @@ -29,11 +26,10 @@ const (
)

type innerClient struct {
keyspaceID uint32
svrUrls []string
pdSvcDiscovery sd.ServiceDiscovery
tokenDispatcher *tokenDispatcher
regionMetaCircuitBreaker *cb.CircuitBreaker[*pdpb.GetRegionResponse]
keyspaceID uint32
svrUrls []string
pdSvcDiscovery sd.ServiceDiscovery
tokenDispatcher *tokenDispatcher

// For service mode switching.
serviceModeKeeper
Expand All @@ -59,7 +55,6 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error {
}
return err
}
c.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", c.option.RegionMetaCircuitBreakerSettings)

return nil
}
Expand Down Expand Up @@ -252,12 +247,3 @@ func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) tso.TSFut
}
return req
}

func isOverloaded(err error) cb.Overloading {
switch status.Code(errors.Cause(err)) {
case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted:
return cb.Yes
default:
return cb.No
}
}
27 changes: 5 additions & 22 deletions client/opt/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (

"github.com/pingcap/errors"

cb "github.com/tikv/pd/client/pkg/circuitbreaker"
"github.com/tikv/pd/client/pkg/retry"
)

Expand All @@ -50,8 +49,6 @@ const (
EnableFollowerHandle
// TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client.
TSOClientRPCConcurrency
// RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests.
RegionMetadataCircuitBreakerSettings

dynamicOptionCount
)
Expand All @@ -72,18 +69,16 @@ type Option struct {
// Dynamic options.
dynamicOptions [dynamicOptionCount]atomic.Value

EnableTSOFollowerProxyCh chan struct{}
RegionMetaCircuitBreakerSettings cb.Settings
EnableTSOFollowerProxyCh chan struct{}
}

// NewOption creates a new PD client option with the default values set.
func NewOption() *Option {
co := &Option{
Timeout: defaultPDTimeout,
MaxRetryTimes: maxInitClusterRetries,
EnableTSOFollowerProxyCh: make(chan struct{}, 1),
InitMetrics: true,
RegionMetaCircuitBreakerSettings: cb.AlwaysClosedSettings,
Timeout: defaultPDTimeout,
MaxRetryTimes: maxInitClusterRetries,
EnableTSOFollowerProxyCh: make(chan struct{}, 1),
InitMetrics: true,
}

co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval)
Expand Down Expand Up @@ -154,11 +149,6 @@ func (o *Option) GetTSOClientRPCConcurrency() int {
return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int)
}

// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls.
func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings {
return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings)
}

// ClientOption configures client.
type ClientOption func(*Option)

Expand Down Expand Up @@ -213,13 +203,6 @@ func WithInitMetricsOption(initMetrics bool) ClientOption {
}
}

// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls
func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption {
return func(op *Option) {
op.RegionMetaCircuitBreakerSettings = config
}
}

// WithBackoffer configures the client with backoffer.
func WithBackoffer(bo *retry.Backoffer) ClientOption {
return func(op *Option) {
Expand Down
32 changes: 27 additions & 5 deletions client/pkg/circuitbreaker/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package circuitbreaker

import (
"context"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -129,12 +130,11 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) {
// Execute calls the given function if the CircuitBreaker is closed and returns the result of execution.
// Execute returns an error instantly if the CircuitBreaker is open.
// https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md
func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, error) {
func (cb *CircuitBreaker[T]) Execute(call func() (Overloading, error)) error {
state, err := cb.onRequest()
if err != nil {
cb.fastFailCounter.Inc()
var defaultValue T
return defaultValue, err
return err
}

defer func() {
Expand All @@ -146,10 +146,10 @@ func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, er
}
}()

result, overloaded, err := call()
overloaded, err := call()
cb.emitMetric(overloaded, err)
cb.onResult(state, overloaded)
return result, err
return err
}

func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) {
Expand Down Expand Up @@ -309,3 +309,25 @@ func (s *State[T]) onResult(overloaded Overloading) {
panic("unknown state")
}
}

// Define context key type
type cbCtxKey struct{}

// Key used to store circuit breaker
var CircuitBreakerKey = cbCtxKey{}

// FromContext retrieves the circuit breaker from the context
func FromContext[T any](ctx context.Context) *CircuitBreaker[T] {
if ctx == nil {
return nil
}

Check warning on line 323 in client/pkg/circuitbreaker/circuit_breaker.go

View check run for this annotation

Codecov / codecov/patch

client/pkg/circuitbreaker/circuit_breaker.go#L322-L323

Added lines #L322 - L323 were not covered by tests
if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker[T]); ok {
return cb
}
return nil
}

// WithCircuitBreaker stores the circuit breaker into a new context
func WithCircuitBreaker[T any](ctx context.Context, cb *CircuitBreaker[T]) context.Context {
return context.WithValue(ctx, CircuitBreakerKey, cb)
}
35 changes: 16 additions & 19 deletions client/pkg/circuitbreaker/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,16 @@ func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) {
cb := NewCircuitBreaker[int]("test_cb", settings)
originalError := errors.New("circuit breaker is open")

result, err := cb.Execute(func() (int, Overloading, error) {
return 42, No, originalError
err := cb.Execute(func() (Overloading, error) {
return No, originalError
})
re.Equal(err, originalError)
re.Equal(42, result)

// same by interpret the result as overloading error
result, err = cb.Execute(func() (int, Overloading, error) {
return 42, Yes, originalError
err = cb.Execute(func() (Overloading, error) {
return Yes, originalError
})
re.Equal(err, originalError)
re.Equal(42, result)
}

func TestCircuitBreakerOpenState(t *testing.T) {
Expand Down Expand Up @@ -118,8 +116,8 @@ func TestCircuitBreakerHalfOpenToOpen(t *testing.T) {
cb.advance(settings.CoolDownInterval)
assertSucceeds(cb, re)
re.Equal(StateHalfOpen, cb.state.stateType)
_, err := cb.Execute(func() (int, Overloading, error) {
return 42, Yes, nil // this trip circuit breaker again
err := cb.Execute(func() (Overloading, error) {
return Yes, nil // this trip circuit breaker again
})
re.NoError(err)
re.Equal(StateHalfOpen, cb.state.stateType)
Expand Down Expand Up @@ -149,10 +147,10 @@ func TestCircuitBreakerHalfOpenFailOverPendingCount(t *testing.T) {
defer func() {
end <- true
}()
_, err := cb.Execute(func() (int, Overloading, error) {
err := cb.Execute(func() (Overloading, error) {
start <- true
<-wait
return 42, No, nil
return No, nil
})
re.NoError(err)
}()
Expand Down Expand Up @@ -188,10 +186,10 @@ func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) {
defer func() {
end <- true
}()
_, err := cb.Execute(func() (int, Overloading, error) {
err := cb.Execute(func() (Overloading, error) {
start <- true
<-wait
return 42, No, nil
return No, nil
})
re.NoError(err)
}()
Expand Down Expand Up @@ -244,27 +242,26 @@ func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreak

func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) {
for range count {
_, err := cb.Execute(func() (int, Overloading, error) {
return 42, overload, nil
err := cb.Execute(func() (Overloading, error) {
return overload, nil
})
re.NoError(err)
}
}

func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) {
var executed = false
_, err := cb.Execute(func() (int, Overloading, error) {
err := cb.Execute(func() (Overloading, error) {
executed = true
return 42, No, nil
return No, nil
})
re.Equal(err, errs.ErrCircuitBreakerOpen)
re.False(executed)
}

func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) {
result, err := cb.Execute(func() (int, Overloading, error) {
return 42, No, nil
err := cb.Execute(func() (Overloading, error) {
return No, nil
})
re.NoError(err)
re.Equal(42, result)
}
Loading

0 comments on commit 579f67b

Please sign in to comment.