diff --git a/client/pkg/circuitbreaker/circuit_breaker.go b/client/pkg/circuitbreaker/circuit_breaker.go index 9c0a5382111..0acee5d5c8d 100644 --- a/client/pkg/circuitbreaker/circuit_breaker.go +++ b/client/pkg/circuitbreaker/circuit_breaker.go @@ -63,12 +63,12 @@ var AlwaysClosedSettings = Settings{ } // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. -type CircuitBreaker[T any] struct { +type CircuitBreaker struct { config *Settings name string mutex sync.Mutex - state *State[T] + state *State successCounter prometheus.Counter errorCounter prometheus.Counter @@ -103,8 +103,8 @@ func (s StateType) String() string { var replacer = strings.NewReplacer(" ", "_", "-", "_") // NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. -func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { - cb := new(CircuitBreaker[T]) +func NewCircuitBreaker(name string, st Settings) *CircuitBreaker { + cb := new(CircuitBreaker) cb.name = name cb.config = &st cb.state = cb.newState(time.Now(), StateClosed) @@ -119,7 +119,7 @@ func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { // ChangeSettings changes the CircuitBreaker settings. // The changes will be reflected only in the next evaluation window. -func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { +func (cb *CircuitBreaker) ChangeSettings(apply func(config *Settings)) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -130,7 +130,7 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { // Execute calls the given function if the CircuitBreaker is closed and returns the result of execution. // Execute returns an error instantly if the CircuitBreaker is open. // https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md -func (cb *CircuitBreaker[T]) Execute(call func() (Overloading, error)) error { +func (cb *CircuitBreaker) Execute(call func() (Overloading, error)) error { state, err := cb.onRequest() if err != nil { cb.fastFailCounter.Inc() @@ -152,7 +152,7 @@ func (cb *CircuitBreaker[T]) Execute(call func() (Overloading, error)) error { return err } -func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { +func (cb *CircuitBreaker) onRequest() (*State, error) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -161,7 +161,7 @@ func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { return state, err } -func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { +func (cb *CircuitBreaker) onResult(state *State, overloaded Overloading) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -170,7 +170,7 @@ func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { state.onResult(overloaded) } -func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { +func (cb *CircuitBreaker) emitMetric(overloaded Overloading, err error) { switch overloaded { case No: cb.successCounter.Inc() @@ -185,9 +185,9 @@ func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { } // State represents the state of CircuitBreaker. -type State[T any] struct { +type State struct { stateType StateType - cb *CircuitBreaker[T] + cb *CircuitBreaker end time.Time pendingCount uint32 @@ -196,7 +196,7 @@ type State[T any] struct { } // newState creates a new State with the given configuration and reset all success/failure counters. -func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State[T] { +func (cb *CircuitBreaker) newState(now time.Time, stateType StateType) *State { var end time.Time var pendingCount uint32 switch stateType { @@ -211,7 +211,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State default: panic("unknown state") } - return &State[T]{ + return &State{ cb: cb, stateType: stateType, pendingCount: pendingCount, @@ -227,7 +227,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State // Open state fails all request, it has a fixed duration of `Settings.CoolDownInterval` and always moves to HalfOpen state at the end of the interval. // HalfOpen state does not have a fixed duration and lasts till `Settings.HalfOpenSuccessCount` are evaluated. // If any of `Settings.HalfOpenSuccessCount` fails then it moves back to Open state, otherwise it moves to Closed state. -func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { +func (s *State) onRequest(cb *CircuitBreaker) (*State, error) { var now = time.Now() switch s.stateType { case StateClosed: @@ -299,7 +299,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { } } -func (s *State[T]) onResult(overloaded Overloading) { +func (s *State) onResult(overloaded Overloading) { switch overloaded { case No: s.successCount++ @@ -317,17 +317,17 @@ type cbCtxKey struct{} var CircuitBreakerKey = cbCtxKey{} // FromContext retrieves the circuit breaker from the context -func FromContext[T any](ctx context.Context) *CircuitBreaker[T] { +func FromContext(ctx context.Context) *CircuitBreaker { if ctx == nil { return nil } - if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker[T]); ok { + if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker); ok { return cb } return nil } // WithCircuitBreaker stores the circuit breaker into a new context -func WithCircuitBreaker[T any](ctx context.Context, cb *CircuitBreaker[T]) context.Context { +func WithCircuitBreaker(ctx context.Context, cb *CircuitBreaker) context.Context { return context.WithValue(ctx, CircuitBreakerKey, cb) } diff --git a/client/pkg/circuitbreaker/circuit_breaker_test.go b/client/pkg/circuitbreaker/circuit_breaker_test.go index d7df55c10ca..e62e55c1ab8 100644 --- a/client/pkg/circuitbreaker/circuit_breaker_test.go +++ b/client/pkg/circuitbreaker/circuit_breaker_test.go @@ -24,7 +24,7 @@ import ( ) // advance emulate the state machine clock moves forward by the given duration -func (cb *CircuitBreaker[T]) advance(duration time.Duration) { +func (cb *CircuitBreaker) advance(duration time.Duration) { cb.state.end = cb.state.end.Add(-duration - 1) } @@ -40,7 +40,7 @@ var minCountToOpen = int(settings.MinQPSForOpen * uint32(settings.ErrorRateWindo func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) originalError := errors.New("circuit breaker is open") err := cb.Execute(func() (Overloading, error) { @@ -57,7 +57,7 @@ func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) { func TestCircuitBreakerOpenState(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) driveQPS(cb, minCountToOpen, Yes, re) re.Equal(StateClosed, cb.state.stateType) assertSucceeds(cb, re) // no error till ErrorRateWindow is finished @@ -68,7 +68,7 @@ func TestCircuitBreakerOpenState(t *testing.T) { func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen/2, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -78,7 +78,7 @@ func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) { func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen/4, Yes, re) driveQPS(cb, minCountToOpen, No, re) @@ -89,7 +89,7 @@ func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) { func TestCircuitBreakerHalfOpenToClosed(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -107,7 +107,7 @@ func TestCircuitBreakerHalfOpenToClosed(t *testing.T) { func TestCircuitBreakerHalfOpenToOpen(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -176,7 +176,7 @@ func TestCircuitBreakerHalfOpenFailOverPendingCount(t *testing.T) { func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) start := make(chan bool) @@ -212,7 +212,7 @@ func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) { func TestCircuitBreakerChangeSettings(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", AlwaysClosedSettings) + cb := NewCircuitBreaker("test_cb", AlwaysClosedSettings) driveQPS(cb, int(AlwaysClosedSettings.MinQPSForOpen*uint32(AlwaysClosedSettings.ErrorRateWindow.Seconds())), Yes, re) cb.advance(AlwaysClosedSettings.ErrorRateWindow) assertSucceeds(cb, re) @@ -229,8 +229,8 @@ func TestCircuitBreakerChangeSettings(t *testing.T) { re.Equal(StateOpen, cb.state.stateType) } -func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker[int] { - cb := NewCircuitBreaker[int]("test_cb", settings) +func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker { + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -240,7 +240,7 @@ func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreak return cb } -func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) { +func driveQPS(cb *CircuitBreaker, count int, overload Overloading, re *require.Assertions) { for range count { err := cb.Execute(func() (Overloading, error) { return overload, nil @@ -249,7 +249,7 @@ func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *requ } } -func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { +func assertFastFail(cb *CircuitBreaker, re *require.Assertions) { var executed = false err := cb.Execute(func() (Overloading, error) { executed = true @@ -259,7 +259,7 @@ func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { re.False(executed) } -func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) { +func assertSucceeds(cb *CircuitBreaker, re *require.Assertions) { err := cb.Execute(func() (Overloading, error) { return No, nil }) diff --git a/client/pkg/utils/grpcutil/grpcutil.go b/client/pkg/utils/grpcutil/grpcutil.go index bfaab6833f0..235e1088747 100644 --- a/client/pkg/utils/grpcutil/grpcutil.go +++ b/client/pkg/utils/grpcutil/grpcutil.go @@ -32,7 +32,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" @@ -75,9 +74,10 @@ func UnaryBackofferInterceptor() grpc.UnaryClientInterceptor { } } -func UnaryCircuitBreakerInterceptor[T any]() grpc.UnaryClientInterceptor { +// UnaryCircuitBreakerInterceptor is a gRPC interceptor that adds a circuit breaker to the call. +func UnaryCircuitBreakerInterceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - cb := circuitbreaker.FromContext[*pdpb.GetRegionResponse](ctx) + cb := circuitbreaker.FromContext(ctx) if cb == nil { return invoker(ctx, method, req, reply, cc, opts...) } @@ -132,7 +132,7 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g retryOpt := grpc.WithChainUnaryInterceptor(UnaryBackofferInterceptor()) // Add circuit breaker interceptor - cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor[any]()) + cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor()) // Add retry related connection parameters backoffOpts := grpc.WithConnectParams(grpc.ConnectParams{ diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index cdab9e81991..c29ef40a83b 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -2073,7 +2073,7 @@ func TestCircuitBreaker(t *testing.T) { cli := setupCli(ctx, re, endpoints) defer cli.Close() - circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { region, err := cli.GetRegion(ctx, []byte("a")) @@ -2128,7 +2128,7 @@ func TestCircuitBreakerOpenAndChangeSettings(t *testing.T) { cli := setupCli(ctx, re, endpoints) defer cli.Close() - circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { region, err := cli.GetRegion(ctx, []byte("a")) @@ -2178,7 +2178,7 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { cli := setupCli(ctx, re, endpoints) defer cli.Close() - circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { region, err := cli.GetRegion(ctx, []byte("a"))