diff --git a/README.md b/README.md index bab8b7133..700fce73f 100644 --- a/README.md +++ b/README.md @@ -391,7 +391,7 @@ e.POST("/path"). Expect(). Status(http.StatusOK) -// custom retry policy +// custom built-in retry policy e.POST("/path"). WithMaxRetries(5). WithRetryPolicy(httpexpect.RetryAllErrors). @@ -404,6 +404,15 @@ e.POST("/path"). WithRetryDelay(time.Second, time.Minute). Expect(). Status(http.StatusOK) + +// custom user-defined retry policy +e.POST("/path"). + WithMaxRetries(5). + WithRetryPolicyFunc(func(resp *http.Response, err error) bool { + return resp.StatusCode == http.StatusTeapot + }). + Expect(). + Status(http.StatusOK) ``` ##### Subdomains and per-request URL diff --git a/request.go b/request.go index 7e2312e34..18e0e24e2 100644 --- a/request.go +++ b/request.go @@ -36,11 +36,13 @@ type Request struct { redirectPolicy RedirectPolicy maxRedirects int - retryPolicy RetryPolicy - maxRetries int - minRetryDelay time.Duration - maxRetryDelay time.Duration - sleepFn func(d time.Duration) <-chan time.Time + retryPolicy RetryPolicy + withRetryPolicyCalled bool + maxRetries int + minRetryDelay time.Duration + maxRetryDelay time.Duration + sleepFn func(d time.Duration) <-chan time.Time + retryPolicyFn func(*http.Response, error) bool timeout time.Duration @@ -755,7 +757,64 @@ func (r *Request) WithRetryPolicy(policy RetryPolicy) *Request { return r } + if r.retryPolicyFn != nil { + opChain.fail(AssertionFailure{ + Type: AssertUsage, + Errors: []error{ + fmt.Errorf("expected: " + + "WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " + + "WithRetryPolicyFunc() is already called"), + }, + }) + return r + } + r.retryPolicy = policy + r.withRetryPolicyCalled = true + + return r +} + +// WithRetryPolicyFunc sets a function to replace built-in policies +// with user-defined policy. +// +// The function expects you to return true to perform a retry. And false to +// not perform a retry. +// +// Example: +// +// req := NewRequestC(config, "POST", "/path") +// req.WithRetryPolicyFunc(func(res *http.Response, err error) bool { +// return resp.StatusCode == http.StatusTeapot +// }) +func (r *Request) WithRetryPolicyFunc(fn func(res *http.Response, err error) bool) *Request { + opChain := r.chain.enter("WithRetryPolicyFunc()") + defer opChain.leave() + + r.mu.Lock() + defer r.mu.Unlock() + + if opChain.failed() { + return r + } + + if !r.checkOrder(opChain, "WithRetryPolicyFunc()") { + return r + } + + if r.withRetryPolicyCalled { + opChain.fail(AssertionFailure{ + Type: AssertUsage, + Errors: []error{ + fmt.Errorf("expected: " + + "WithRetryPolicyFunc() and WithRetryPolicy() should be mutual exclusive, " + + "WithRetryPolicy() is already called"), + }, + }) + return r + } + + r.retryPolicyFn = fn return r } @@ -2332,6 +2391,10 @@ func (r *Request) retryRequest(reqFunc func() (*http.Response, error)) ( } func (r *Request) shouldRetry(resp *http.Response, err error) bool { + if r.retryPolicyFn != nil { + return r.retryPolicyFn(resp, err) + } + var ( isTemporaryNetworkError bool // Deprecated isTimeoutError bool diff --git a/request_test.go b/request_test.go index e83a31f40..695ba39f7 100644 --- a/request_test.go +++ b/request_test.go @@ -3349,6 +3349,60 @@ func TestRequest_RetriesCancellation(t *testing.T) { assert.Equal(t, 1, callCount) } +func TestRequest_WithRetryPolicyFunc(t *testing.T) { + tests := []struct { + name string + fn func(res *http.Response, err error) bool + callCount int + }{ + { + name: "should not retry", + fn: func(res *http.Response, err error) bool { + return false + }, + callCount: 1, + }, + { + name: "should retry", + fn: func(res *http.Response, err error) bool { + return true + }, + callCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callCount := 0 + + client := &mockClient{ + resp: http.Response{ + StatusCode: http.StatusTeapot, + }, + cb: func(req *http.Request) { + callCount++ + }, + } + + cfg := Config{ + Client: client, + Reporter: newMockReporter(t), + } + + req := NewRequestC(cfg, http.MethodGet, "/url"). + WithMaxRetries(1). + WithRetryDelay(0, 0). + WithRetryPolicyFunc(tt.fn) + req.chain.assert(t, success) + + resp := req.Expect() + resp.chain.assert(t, success) + + assert.Equal(t, tt.callCount, callCount) + }) + } +} + func TestRequest_Conflicts(t *testing.T) { client := &mockClient{} @@ -3492,6 +3546,44 @@ func TestRequest_Conflicts(t *testing.T) { }) } }) + + t.Run("retry policy conflict", func(t *testing.T) { + cases := []struct { + name string + fn func(req *Request) + }{ + { + "WithRetryPolicyFunc", + func(req *Request) { + req.WithRetryPolicyFunc(func(res *http.Response, err error) bool { + return res.StatusCode == http.StatusTeapot + }) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := NewRequestC(config, "GET", "url") + + tc.fn(req) + req.chain.assert(t, success) + + req.WithRetryPolicy(RetryAllErrors) + req.chain.assert(t, failure) + }) + + t.Run(tc.name+" - reversed", func(t *testing.T) { + req := NewRequestC(config, "GET", "url") + + req.WithRetryPolicy(RetryAllErrors) + req.chain.assert(t, success) + + tc.fn(req) + req.chain.assert(t, failure) + }) + } + }) } func TestRequest_Usage(t *testing.T) { @@ -3642,6 +3734,15 @@ func TestRequest_Usage(t *testing.T) { prepFails: false, expectFails: true, }, + { + name: "WithRetryPolicyFunc - nil argument", + client: &mockClient{}, + prepFunc: func(req *Request) { + req.WithRetryPolicyFunc(nil) + }, + prepFails: false, + expectFails: false, + }, } for _, tc := range cases { @@ -3934,6 +4035,14 @@ func TestRequest_Order(t *testing.T) { req.WithMultipart() }, }, + { + name: "WithRetryPolicyFunc after Expect", + afterFunc: func(req *Request) { + req.WithRetryPolicyFunc(func(res *http.Response, err error) bool { + return res.StatusCode == http.StatusTeapot + }) + }, + }, } for _, tc := range cases {