Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <[email protected]>
  • Loading branch information
rleungx committed Dec 26, 2024
1 parent 579f67b commit 13d663d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 39 deletions.
36 changes: 18 additions & 18 deletions client/pkg/circuitbreaker/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ var AlwaysClosedSettings = Settings{
}

// CircuitBreaker is a state machine to prevent sending requests that are likely to fail.
type CircuitBreaker[T any] struct {
type CircuitBreaker struct {
config *Settings
name string

mutex sync.Mutex
state *State[T]
state *State

successCounter prometheus.Counter
errorCounter prometheus.Counter
Expand Down Expand Up @@ -103,8 +103,8 @@ func (s StateType) String() string {
var replacer = strings.NewReplacer(" ", "_", "-", "_")

// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings.
func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] {
cb := new(CircuitBreaker[T])
func NewCircuitBreaker(name string, st Settings) *CircuitBreaker {
cb := new(CircuitBreaker)
cb.name = name
cb.config = &st
cb.state = cb.newState(time.Now(), StateClosed)
Expand All @@ -119,7 +119,7 @@ func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] {

// ChangeSettings changes the CircuitBreaker settings.
// The changes will be reflected only in the next evaluation window.
func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) {
func (cb *CircuitBreaker) ChangeSettings(apply func(config *Settings)) {
cb.mutex.Lock()
defer cb.mutex.Unlock()

Expand All @@ -130,7 +130,7 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) {
// Execute calls the given function if the CircuitBreaker is closed and returns the result of execution.
// Execute returns an error instantly if the CircuitBreaker is open.
// https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md
func (cb *CircuitBreaker[T]) Execute(call func() (Overloading, error)) error {
func (cb *CircuitBreaker) Execute(call func() (Overloading, error)) error {
state, err := cb.onRequest()
if err != nil {
cb.fastFailCounter.Inc()
Expand All @@ -152,7 +152,7 @@ func (cb *CircuitBreaker[T]) Execute(call func() (Overloading, error)) error {
return err
}

func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) {
func (cb *CircuitBreaker) onRequest() (*State, error) {
cb.mutex.Lock()
defer cb.mutex.Unlock()

Expand All @@ -161,7 +161,7 @@ func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) {
return state, err
}

func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) {
func (cb *CircuitBreaker) onResult(state *State, overloaded Overloading) {
cb.mutex.Lock()
defer cb.mutex.Unlock()

Expand All @@ -170,7 +170,7 @@ func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) {
state.onResult(overloaded)
}

func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) {
func (cb *CircuitBreaker) emitMetric(overloaded Overloading, err error) {
switch overloaded {
case No:
cb.successCounter.Inc()
Expand All @@ -185,9 +185,9 @@ func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) {
}

// State represents the state of CircuitBreaker.
type State[T any] struct {
type State struct {
stateType StateType
cb *CircuitBreaker[T]
cb *CircuitBreaker
end time.Time

pendingCount uint32
Expand All @@ -196,7 +196,7 @@ type State[T any] struct {
}

// newState creates a new State with the given configuration and reset all success/failure counters.
func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State[T] {
func (cb *CircuitBreaker) newState(now time.Time, stateType StateType) *State {
var end time.Time
var pendingCount uint32
switch stateType {
Expand All @@ -211,7 +211,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State
default:
panic("unknown state")
}
return &State[T]{
return &State{
cb: cb,
stateType: stateType,
pendingCount: pendingCount,
Expand All @@ -227,7 +227,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State
// Open state fails all request, it has a fixed duration of `Settings.CoolDownInterval` and always moves to HalfOpen state at the end of the interval.
// HalfOpen state does not have a fixed duration and lasts till `Settings.HalfOpenSuccessCount` are evaluated.
// If any of `Settings.HalfOpenSuccessCount` fails then it moves back to Open state, otherwise it moves to Closed state.
func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) {
func (s *State) onRequest(cb *CircuitBreaker) (*State, error) {
var now = time.Now()
switch s.stateType {
case StateClosed:
Expand Down Expand Up @@ -299,7 +299,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) {
}
}

func (s *State[T]) onResult(overloaded Overloading) {
func (s *State) onResult(overloaded Overloading) {
switch overloaded {
case No:
s.successCount++
Expand All @@ -317,17 +317,17 @@ type cbCtxKey struct{}
var CircuitBreakerKey = cbCtxKey{}

// FromContext retrieves the circuit breaker from the context
func FromContext[T any](ctx context.Context) *CircuitBreaker[T] {
func FromContext(ctx context.Context) *CircuitBreaker {
if ctx == nil {
return nil
}
if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker[T]); ok {
if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker); ok {
return cb
}
return nil
}

// WithCircuitBreaker stores the circuit breaker into a new context
func WithCircuitBreaker[T any](ctx context.Context, cb *CircuitBreaker[T]) context.Context {
func WithCircuitBreaker(ctx context.Context, cb *CircuitBreaker) context.Context {
return context.WithValue(ctx, CircuitBreakerKey, cb)
}
28 changes: 14 additions & 14 deletions client/pkg/circuitbreaker/circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

// advance emulate the state machine clock moves forward by the given duration
func (cb *CircuitBreaker[T]) advance(duration time.Duration) {
func (cb *CircuitBreaker) advance(duration time.Duration) {
cb.state.end = cb.state.end.Add(-duration - 1)
}

Expand All @@ -40,7 +40,7 @@ var minCountToOpen = int(settings.MinQPSForOpen * uint32(settings.ErrorRateWindo

func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
originalError := errors.New("circuit breaker is open")

err := cb.Execute(func() (Overloading, error) {
Expand All @@ -57,7 +57,7 @@ func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) {

func TestCircuitBreakerOpenState(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
driveQPS(cb, minCountToOpen, Yes, re)
re.Equal(StateClosed, cb.state.stateType)
assertSucceeds(cb, re) // no error till ErrorRateWindow is finished
Expand All @@ -68,7 +68,7 @@ func TestCircuitBreakerOpenState(t *testing.T) {

func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)
driveQPS(cb, minCountToOpen/2, Yes, re)
cb.advance(settings.ErrorRateWindow)
Expand All @@ -78,7 +78,7 @@ func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) {

func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)
driveQPS(cb, minCountToOpen/4, Yes, re)
driveQPS(cb, minCountToOpen, No, re)
Expand All @@ -89,7 +89,7 @@ func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) {

func TestCircuitBreakerHalfOpenToClosed(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)
driveQPS(cb, minCountToOpen, Yes, re)
cb.advance(settings.ErrorRateWindow)
Expand All @@ -107,7 +107,7 @@ func TestCircuitBreakerHalfOpenToClosed(t *testing.T) {

func TestCircuitBreakerHalfOpenToOpen(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)
driveQPS(cb, minCountToOpen, Yes, re)
cb.advance(settings.ErrorRateWindow)
Expand Down Expand Up @@ -176,7 +176,7 @@ func TestCircuitBreakerHalfOpenFailOverPendingCount(t *testing.T) {

func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) {
re := require.New(t)
cb := NewCircuitBreaker[int]("test_cb", settings)
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)

start := make(chan bool)
Expand Down Expand Up @@ -212,7 +212,7 @@ func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) {
func TestCircuitBreakerChangeSettings(t *testing.T) {
re := require.New(t)

cb := NewCircuitBreaker[int]("test_cb", AlwaysClosedSettings)
cb := NewCircuitBreaker("test_cb", AlwaysClosedSettings)
driveQPS(cb, int(AlwaysClosedSettings.MinQPSForOpen*uint32(AlwaysClosedSettings.ErrorRateWindow.Seconds())), Yes, re)
cb.advance(AlwaysClosedSettings.ErrorRateWindow)
assertSucceeds(cb, re)
Expand All @@ -229,8 +229,8 @@ func TestCircuitBreakerChangeSettings(t *testing.T) {
re.Equal(StateOpen, cb.state.stateType)
}

func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker[int] {
cb := NewCircuitBreaker[int]("test_cb", settings)
func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker {
cb := NewCircuitBreaker("test_cb", settings)
re.Equal(StateClosed, cb.state.stateType)
driveQPS(cb, minCountToOpen, Yes, re)
cb.advance(settings.ErrorRateWindow)
Expand All @@ -240,7 +240,7 @@ func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreak
return cb
}

func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) {
func driveQPS(cb *CircuitBreaker, count int, overload Overloading, re *require.Assertions) {
for range count {
err := cb.Execute(func() (Overloading, error) {
return overload, nil
Expand All @@ -249,7 +249,7 @@ func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *requ
}
}

func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) {
func assertFastFail(cb *CircuitBreaker, re *require.Assertions) {
var executed = false
err := cb.Execute(func() (Overloading, error) {
executed = true
Expand All @@ -259,7 +259,7 @@ func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) {
re.False(executed)
}

func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) {
func assertSucceeds(cb *CircuitBreaker, re *require.Assertions) {
err := cb.Execute(func() (Overloading, error) {
return No, nil
})
Expand Down
8 changes: 4 additions & 4 deletions client/pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/log"

"github.com/tikv/pd/client/errs"
Expand Down Expand Up @@ -75,9 +74,10 @@ func UnaryBackofferInterceptor() grpc.UnaryClientInterceptor {
}
}

func UnaryCircuitBreakerInterceptor[T any]() grpc.UnaryClientInterceptor {
// UnaryCircuitBreakerInterceptor is a gRPC interceptor that adds a circuit breaker to the call.
func UnaryCircuitBreakerInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
cb := circuitbreaker.FromContext[*pdpb.GetRegionResponse](ctx)
cb := circuitbreaker.FromContext(ctx)
if cb == nil {
return invoker(ctx, method, req, reply, cc, opts...)
}
Expand Down Expand Up @@ -132,7 +132,7 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g
retryOpt := grpc.WithChainUnaryInterceptor(UnaryBackofferInterceptor())

// Add circuit breaker interceptor
cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor[any]())
cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor())

// Add retry related connection parameters
backoffOpts := grpc.WithConnectParams(grpc.ConnectParams{
Expand Down
6 changes: 3 additions & 3 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@ func TestCircuitBreaker(t *testing.T) {
cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(ctx, []byte("a"))
Expand Down Expand Up @@ -2128,7 +2128,7 @@ func TestCircuitBreakerOpenAndChangeSettings(t *testing.T) {
cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(ctx, []byte("a"))
Expand Down Expand Up @@ -2178,7 +2178,7 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) {
cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(ctx, []byte("a"))
Expand Down

0 comments on commit 13d663d

Please sign in to comment.