[Fix] Support custom retry logic per method #1081

Open
wants to merge 4 commits into
base: main
Changes from 3 commits
3 changes: 1 addition & 2 deletions client/client.go
@@ -63,8 +63,7 @@ func (c *DatabricksClient) GetOAuthToken(ctx context.Context, authDetails string

 // Do sends an HTTP request against path.
 func (c *DatabricksClient) Do(ctx context.Context, method, path string,
-	headers map[string]string, request, response any,
Contributor Author

I can add this back in if desired, just a small formatting change.

-	visitors ...func(*http.Request) error) error {
+	headers map[string]string, request, response any, visitors ...func(*http.Request) error) error {
Contributor
@renaudhartert-db, Nov 4, 2024

[optional] I would go with one parameter per line to echo the way function calls and struct declarations are made:

func (c *DatabricksClient) Do(
	ctx context.Context,
	method string,
	path string,
	headers map[string]string,
	request any,
	response any,
	visitors ...func(*http.Request) error,
) error {

There are a couple of similar patterns in the Go standard library, but not many. One of the reasons is that long lists of parameters are usually replaced with a struct (https://google.github.io/styleguide/go/best-practices#option-structure). I have wanted to make that change for quite some time, but it didn't feel right to send a PR just for that.

Contributor Author

Yeah. This Do() method should be formulated either as a struct argument or with functional options, since everything after path is optional. I'd prefer not to change this signature unless necessary, as it is used in multiple places in the TF provider.

Contributor

Agreed on not using a struct, as this is out of scope for this PR. Let's at least format idiomatically then. Our Go style does not mandate an 80-character line length, which makes the current formatting quite arbitrary. I'm fine with either having everything on a single line or one argument per line.

 	opts := []httpclient.DoOption{}
 	for _, v := range visitors {
 		opts = append(opts, httpclient.WithRequestVisitor(v))
27 changes: 17 additions & 10 deletions config/api_client.go
@@ -6,9 +6,11 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"regexp"
 	"time"

 	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/common"
 	"github.com/databricks/databricks-sdk-go/credentials"
 	"github.com/databricks/databricks-sdk-go/httpclient"
 	"github.com/databricks/databricks-sdk-go/useragent"
@@ -73,17 +75,22 @@ func (c *Config) NewApiClient() (*httpclient.ApiClient, error) {
 				return nil
 			},
 		},
-		TransientErrors: []string{
-			"REQUEST_LIMIT_EXCEEDED", // This is temporary workaround for SCIM API returning 500. Remove when it's fixed
-		},
 		ErrorMapper: apierr.GetAPIError,
-		ErrorRetriable: func(ctx context.Context, err error) bool {
-			var apiErr *apierr.APIError
-			if errors.As(err, &apiErr) {
-				return apiErr.IsRetriable(ctx)
-			}
-			return false
-		},
+		ErrorRetriable: httpclient.CombineRetriers(
+			func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
+				var apiErr *apierr.APIError
+				if errors.As(err, &apiErr) {
+					return apiErr.IsRetriable(ctx)
+				}
+				return false
+			},
+			httpclient.RetryUrlErrors,
Contributor Author

I moved this out of the ApiClient so that the client's retry logic is defined in a single place. The downside is that you need to add this explicitly to your ErrorRetriable if you don't use DefaultErrorRetriable. Happy to make this the default behavior; let me know what you think.

+			httpclient.RetryTransientErrors([]string{"REQUEST_LIMIT_EXCEEDED"}),
+			httpclient.RetryMatchedRequests([]httpclient.RestApiMatcher{
+				// Get Permissions API can be retried on 504
+				{Method: http.MethodGet, Path: *regexp.MustCompile(`/api/2.0/permissions/[^/]+/[^/]+`)},
+			}, httpclient.RetryOnGatewayTimeout),
+		),
 	}), nil
 }

54 changes: 54 additions & 0 deletions config/api_client_test.go
@@ -0,0 +1,54 @@
+package config
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+
+	"github.com/databricks/databricks-sdk-go/httpclient"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+type hc func(r *http.Request) (*http.Response, error)
+
+func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) {
+	return cb(r)
+}
+
+func (cb hc) SkipRetryOnIO() bool {
+	return true
+}
+
+func TestApiClient_RetriesGetPermissionsOnGatewayTimeout(t *testing.T) {
+	requestCount := 0
+	c := &Config{
+		HTTPTransport: hc(func(r *http.Request) (*http.Response, error) {
+			initialRequestCount := requestCount
+			requestCount++
+			if initialRequestCount == 0 {
+				return &http.Response{
+					Request: r,
+					StatusCode: http.StatusGatewayTimeout,
+					Body: io.NopCloser(strings.NewReader(
+						fmt.Sprintf(`{"error_code":"TEMPORARILY_UNAVAILABLE", "message":"The service at %s is taking too long to process your request. Please try again later or try a faster operation."}`, r.URL))),
+				}, nil
+			}
+			return &http.Response{
+				Request: r,
+				StatusCode: http.StatusOK,
+				Body: io.NopCloser(strings.NewReader(`{"permissions": ["can_run_queries"]}`)),
+			}, nil
+		}),
+	}
+	client, err := c.NewApiClient()
+	require.NoError(t, err)
+	ctx := context.Background()
+	var res map[string][]string
+	err = client.Do(ctx, "GET", "/api/2.0/permissions/object/id", httpclient.WithResponseUnmarshal(&res))
+	assert.NoError(t, err)
+	assert.Equal(t, map[string][]string{"permissions": {"can_run_queries"}}, res)
+}
Comment on lines +1 to +54
Contributor
@renaudhartert-db, Nov 4, 2024

I'd recommend structuring the unit tests differently, with this test focused on how ApiClient handles a mocked ErrorRetrier (or the absence of an ErrorRetrier). The tests of the ErrorRetrier implementations themselves (e.g. verifying that the path pattern matches properly) should live in httpclient/errors_test.go.

This test could look like the following (I did not verify that the code works):

type mock struct {
	MaxFails     int            // number of times the failed Response is returned
	FailResponse *http.Response // response to return in case of fail
	FailError    error          // error to return in case of fail
	NumCalls     int            // total number of calls
}

func (m *mock) RoundTrip(r *http.Request) (*http.Response, error) {
	m.NumCalls++
	if m.NumCalls <= m.MaxFails {
		return m.FailResponse, m.FailError
	}
	return &http.Response{
		Request:    r,
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(`{}`)),
	}, nil
}

func (m *mock) SkipRetryOnIO() bool {
	return true
}

func TestApiClient_Do_retries(t *testing.T) {
	testCases := []struct {
		desc string
		mock *mock
		// ErrorRetrier is the name used in this comment; the diff in this PR
		// calls the type httpclient.ErrorRetryer.
		errorRetrier ErrorRetrier
		wantNumCalls int
	}{
		{
			desc: "nil retrier",
			mock: &mock{
				MaxFails:     1,
				FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout},
			},
			wantNumCalls: 1,
		},
		{
			desc: "no retry",
			mock: &mock{
				MaxFails:     1,
				FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout},
			},
			errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool {
				return false
			},
			wantNumCalls: 1,
		},
		{
			desc: "retry 1 time",
			mock: &mock{
				MaxFails:     1,
				FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout},
			},
			errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool {
				return true
			},
			wantNumCalls: 2,
		},
		{
			desc: "retry 2 times",
			mock: &mock{
				MaxFails:     2,
				FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout},
			},
			errorRetrier: func(_ context.Context, _ *http.Request, _ *common.ResponseWrapper, _ error) bool {
				return true
			},
			wantNumCalls: 3,
		},
		{
			desc: "retry 3 times",
			mock: &mock{
				MaxFails:     3,
				FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout},
			},
			errorRetrier: func(_ context.Context, _ *http.Request, _ *common.ResponseWrapper, _ error) bool {
				return true
			},
			wantNumCalls: 4,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			cfg := &Config{HTTPTransport: tc.mock}
			client, err := cfg.NewApiClient()
			if err != nil {
				t.Fatalf("NewApiClient: %v", err)
			}
			// Assumes the client exposes a settable ErrorRetrier; adjust to
			// however the retrier ends up being injected.
			client.ErrorRetrier = tc.errorRetrier

			// The returned error is not checked; only the call count matters here.
			_ = client.Do(context.Background(), "GET", "test-path")
			gotNumCalls := tc.mock.NumCalls

			if gotNumCalls != tc.wantNumCalls {
				t.Errorf("got %d calls, want %d", gotNumCalls, tc.wantNumCalls)
			}
		})
	}
}

Please feel free to ignore this comment if this is too much work or if the ApiClient cannot be instrumented that easily.

Contributor Author

It definitely can be instrumented this way, and this is a nice test case to use (I'll adapt it and include it in this PR). However, I did want to specifically test the get permissions pathway. Essentially, this tests that "the client returned by Config.NewApiClient() correctly implements retry on 504." I will add more test cases here though.

Contributor

> I did want to specifically test the get permissions pathway

Sounds good to me as long as this complements the overall testing of the retry logic.

17 changes: 10 additions & 7 deletions config/config.go
@@ -311,13 +311,16 @@ func (c *Config) EnsureResolved() error {
 		HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second,
 		Transport: c.HTTPTransport,
 		ErrorMapper: c.refreshTokenErrorMapper,
-		TransientErrors: []string{
-			"throttled",
-			"too many requests",
-			"429",
-			"request limit exceeded",
-			"rate limit",
-		},
+		ErrorRetriable: httpclient.CombineRetriers(
+			httpclient.DefaultErrorRetriable,
+			httpclient.RetryTransientErrors([]string{
+				"throttled",
+				"too many requests",
+				"429",
+				"request limit exceeded",
+				"rate limit",
+			}),
+		),
 	})
 	if c.azureTenantIdFetchClient == nil {
 		c.azureTenantIdFetchClient = &http.Client{
57 changes: 12 additions & 45 deletions httpclient/api_client.go
@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"net/url"
 	"runtime"
-	"strings"
 	"time"

 	"github.com/databricks/databricks-sdk-go/common"
@@ -35,9 +34,8 @@ type ClientConfig struct {
 	DebugTruncateBytes int
 	RateLimitPerSecond int

-	ErrorMapper     func(ctx context.Context, resp common.ResponseWrapper) error
-	ErrorRetriable  func(ctx context.Context, err error) bool
-	TransientErrors []string
+	ErrorMapper    func(ctx context.Context, resp common.ResponseWrapper) error
+	ErrorRetriable ErrorRetryer

 	Transport http.RoundTripper
 }
@@ -130,7 +128,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio
 			// merge client-wide and request-specific visitors
 			visitors = append(visitors, o.in)
 		}
-
 	}
 	// Use default AuthVisitor if none is provided
 	if authVisitor == nil {
@@ -170,45 +167,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio
 	return nil
 }

-func (c *ApiClient) isRetriable(ctx context.Context, err error) bool {
-	if c.config.ErrorRetriable(ctx, err) {
-		return true
-	}
-	if isRetriableUrlError(err) {
-		// all IO errors are retriable
-		logger.Debugf(ctx, "Attempting retry because of IO error: %s", err)
-		return true
-	}
-	message := err.Error()
-	// Handle transient errors for retries
-	for _, substring := range c.config.TransientErrors {
-		if strings.Contains(message, substring) {
-			logger.Debugf(ctx, "Attempting retry because of %#v", substring)
-			return true
-		}
-	}
-	// some API's recommend retries on HTTP 500, but we'll add that later
-	return false
-}
-
-// Common error-handling logic for all responses that may need to be retried.
-//
-// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed
-// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt.
-//
-// Always returns nil for the first parameter as there is no meaningful response body to return in the error case.
-//
-// If it is certain that an error should not be retried, use failRequest() instead.
-func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) {
-	if !c.isRetriable(ctx, err) {
-		return nil, retries.Halt(err)
-	}
-	if resetErr := body.Reset(); resetErr != nil {
-		return nil, retries.Halt(resetErr)
-	}
-	return nil, retries.Continue(err)
-}
-
 // Fails the request with a retries.Err to halt future retries.
 func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) {
 	err = fmt.Errorf("%s: %w", msg, err)
@@ -299,7 +257,16 @@ func (c *ApiClient) attempt(

 		// proactively release the connections in HTTP connection pool
 		c.httpClient.CloseIdleConnections()
-		return c.handleError(ctx, err, requestBody)
+
+		// Non-retriable errors can be returned immediately.
+		if !c.config.ErrorRetriable(ctx, request, &responseWrapper, err) {
+			return nil, retries.Halt(err)
+		}
+		// Retriable errors may require the request body to be reset.
+		if resetErr := requestBody.Reset(); resetErr != nil {
+			return nil, retries.Halt(resetErr)
+		}
+		return nil, retries.Continue(err)
 	}
 }

80 changes: 70 additions & 10 deletions httpclient/errors.go
@@ -7,9 +7,11 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"regexp"
 	"strings"

 	"github.com/databricks/databricks-sdk-go/common"
+	"github.com/databricks/databricks-sdk-go/logger"
 )

 type HttpError struct {
@@ -45,17 +47,39 @@ func DefaultErrorMapper(ctx context.Context, resp common.ResponseWrapper) error
 	}
 }

-func DefaultErrorRetriable(ctx context.Context, err error) bool {
-	var httpError *HttpError
-	if errors.As(err, &httpError) {
-		if httpError.StatusCode == http.StatusTooManyRequests {
-			return true
-		}
-		if httpError.StatusCode == http.StatusGatewayTimeout {
-			return true
+type ErrorRetryer func(context.Context, *http.Request, *common.ResponseWrapper, error) bool
+
+func DefaultErrorRetriable(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool {
+	return CombineRetriers(
+		RetryOnTooManyRequests,
+		RetryOnGatewayTimeout,
+		RetryUrlErrors,
+	)(ctx, req, resp, err)
+}
+
+func RetryOnTooManyRequests(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool {
+	if resp.Response == nil {
+		return false
+	}
+	return resp.Response.StatusCode == http.StatusTooManyRequests
+}
+
+func RetryOnGatewayTimeout(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool {
+	if resp.Response == nil {
+		return false
+	}
+	return resp.Response.StatusCode == http.StatusGatewayTimeout
+}
+
+func CombineRetriers(retriers ...ErrorRetryer) ErrorRetryer {
+	return func(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool {
+		for _, retrier := range retriers {
+			if retrier(ctx, req, resp, err) {
+				return true
+			}
 		}
+		return false
 	}
-	return false
 }

 var urlErrorTransientErrorMessages = []string{
@@ -66,15 +90,51 @@ var urlErrorTransientErrorMessages = []string{
 	"i/o timeout",
 }

-func isRetriableUrlError(err error) bool {
+func RetryUrlErrors(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
 	var urlError *url.Error
 	if !errors.As(err, &urlError) {
 		return false
 	}
 	for _, msg := range urlErrorTransientErrorMessages {
 		if strings.Contains(err.Error(), msg) {
+			logger.Debugf(ctx, "Attempting retry because of IO error: %s", err)
 			return true
 		}
 	}
 	return false
 }
+
+func RetryTransientErrors(errors []string) ErrorRetryer {
+	return func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool {
+		message := err.Error()
+		// Handle transient errors for retries
+		for _, substring := range errors {
+			if strings.Contains(message, substring) {
+				logger.Debugf(ctx, "Attempting retry because of %#v", substring)
+				return true
+			}
+		}
+		return false
+	}
+}
+
+type RestApiMatcher struct {
+	Method string
+	Path   regexp.Regexp
+}
+
+func (m *RestApiMatcher) Matches(req *http.Request) bool {
+	return req.Method == m.Method && m.Path.MatchString(req.URL.Path)
+}
+
+func RetryMatchedRequests(methods []RestApiMatcher, retryer ErrorRetryer) ErrorRetryer {
+	return func(ctx context.Context, r *http.Request, rw *common.ResponseWrapper, err error) bool {
+		for _, m := range methods {
+			if m.Matches(r) && retryer(ctx, r, rw, err) {
+				logger.Debugf(ctx, "Attempting retry because of gateway timeout")
+				return true
+			}
+		}
+		return false
+	}
+}