From d8d00bca6fd01d4893104cb0a775d5dc4b40f2de Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Fri, 8 Nov 2024 21:17:21 -0600 Subject: [PATCH 1/8] Adding error handling for status code 5xx --- go.mod | 3 +++ pkg/osv/osv.go | 41 +++++++++++++++++++++++++---------------- pkg/osv/osv_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 16 deletions(-) create mode 100644 pkg/osv/osv_test.go diff --git a/go.mod b/go.mod index caf5f92b7b..effda6b561 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/package-url/packageurl-go v0.1.3 github.com/pandatix/go-cvss v0.6.2 github.com/spdx/tools-golang v0.5.5 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/pretty v1.2.1 github.com/tidwall/sjson v1.2.5 @@ -57,6 +58,7 @@ require ( github.com/containerd/stargz-snapshotter/estargz v0.15.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.8.1 // indirect @@ -84,6 +86,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc3 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index e27d700e60..3fb85b3f79 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -6,7 +6,7 @@ import ( "encoding/json" "fmt" "io" - "math/rand" + "net/http" "time" @@ -16,7 +16,7 @@ import ( "golang.org/x/sync/errgroup" ) -const ( +var ( // QueryEndpoint is the URL for posting queries to OSV. QueryEndpoint = "https://api.osv.dev/v1/querybatch" // GetEndpoint is the URL for getting vulnerabilities from OSV. @@ -158,7 +158,7 @@ func chunkBy[T any](items []T, chunkSize int) [][]T { // checkResponseError checks if the response has an error. func checkResponseError(resp *http.Response) error { - if resp.StatusCode == http.StatusOK { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } @@ -311,21 +311,30 @@ func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, er var resp *http.Response var err error - for i := range maxRetryAttempts { - // rand is initialized with a random number (since go1.20), and is also safe to use concurrently - // we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests - // #nosec G404 - jitterAmount := (rand.Float64() * float64(jitterMultiplier) * float64(i)) - time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond) - + for i := 0; i < maxRetryAttempts; i++ { resp, err = action() - if err == nil { - // Check the response for HTTP errors - err = checkResponseError(resp) - if err == nil { - break - } + if err != nil { + + sleepDuration := time.Duration(i*jitterMultiplier) * time.Second + time.Sleep(sleepDuration) + continue } + + if resp.StatusCode >= 500 { + + resp.Body.Close() + sleepDuration := time.Duration(i*jitterMultiplier) * time.Second + time.Sleep(sleepDuration) + continue + } + + // Success or client error, do not retry + break + } + + if resp != nil && resp.StatusCode >= 500 { + resp.Body.Close() + return nil, fmt.Errorf("received %d status code after %d attempts", resp.StatusCode, maxRetryAttempts) } return resp, err diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go new file mode 100644 index 0000000000..5148ba353d --- /dev/null +++ b/pkg/osv/osv_test.go @@ -0,0 +1,43 @@ +// pkg/osv/osv_test.go + +package osv + +import ( + "log" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRetryOn5xx(t *testing.T) { + attempt := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt++ + w.WriteHeader(http.StatusInternalServerError) // 500 + })) + defer server.Close() + + // Override the QueryEndpoint for testing + originalQueryEndpoint := QueryEndpoint + QueryEndpoint = server.URL + defer func() { QueryEndpoint = originalQueryEndpoint }() + + client := &http.Client{ + Timeout: 2 * time.Second, + } + + resp, err := makeRetryRequest(func() (*http.Response, error) { + req, _ := http.NewRequest(http.MethodPost, QueryEndpoint, nil) + req.Header.Set("Content-Type", "application/json") + return client.Do(req) + }) + + log.Printf("TestRetryOn5xx: resp = %v, err = %v", resp, err) + + assert.Nil(t, resp, "Expected response to be nil after retries on 5xx errors") + assert.Error(t, err, "Expected an error after retries on 5xx errors") + assert.Equal(t, maxRetryAttempts, attempt, "Expected number of attempts to equal maxRetryAttempts") +} From c91540be86c512e1fdc95ed897a4f991d0ddb4e2 Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Fri, 8 Nov 2024 21:47:34 -0600 Subject: [PATCH 2/8] updating osv.go comment for handling 200 inside makeRetryRequest --- pkg/osv/osv.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index 3fb85b3f79..87278a0230 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -306,7 +306,7 @@ func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBat return &hydrated, nil } -// makeRetryRequest will return an error on both network errors, and if the response is not 200 +// makeRetryRequest will return an error on both network errors, and if the response is not 2xx func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) { var resp *http.Response var err error From 03db1613f57ca52c7091411897d80306d16b01e8 Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Sun, 10 Nov 2024 00:09:39 -0600 Subject: [PATCH 3/8] refactor: use standard library for test assertions --- pkg/osv/osv_test.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go index 5148ba353d..b1ad1021ef 100644 --- a/pkg/osv/osv_test.go +++ b/pkg/osv/osv_test.go @@ -1,5 +1,3 @@ -// pkg/osv/osv_test.go - package osv import ( @@ -8,8 +6,6 @@ import ( "net/http/httptest" "testing" "time" - - "github.com/stretchr/testify/assert" ) func TestRetryOn5xx(t *testing.T) { @@ -37,7 +33,18 @@ func TestRetryOn5xx(t *testing.T) { log.Printf("TestRetryOn5xx: resp = %v, err = %v", resp, err) - assert.Nil(t, resp, "Expected response to be nil after retries on 5xx errors") - assert.Error(t, err, "Expected an error after retries on 5xx errors") - assert.Equal(t, maxRetryAttempts, attempt, "Expected number of attempts to equal maxRetryAttempts") + // Assertion: resp should be nil + if resp != nil { + t.Errorf("Expected response to be nil after retries on 5xx errors, but got: %v", resp) + } + + // Assertion: err should not be nil + if err == nil { + t.Errorf("Expected an error after retries on 5xx errors, but got none") + } + + // Assertion: number of attempts should equal maxRetryAttempts + if attempt != maxRetryAttempts { + t.Errorf("Expected number of attempts to equal maxRetryAttempts (%d), but got: %d", maxRetryAttempts, attempt) + } } From a30e28e7d9f8e0d61d907095f3678aae46713aa7 Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Sun, 10 Nov 2024 15:52:34 -0600 Subject: [PATCH 4/8] fix: correct makeRetryRequest retry logic and update tests --- pkg/osv/osv.go | 36 +++++++------- pkg/osv/osv_test.go | 115 +++++++++++++++++++++++++++++++------------- 2 files changed, 101 insertions(+), 50 deletions(-) diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index 87278a0230..a77ff9f6c5 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -6,7 +6,7 @@ import ( "encoding/json" "fmt" "io" - + "math/rand" "net/http" "time" @@ -313,31 +313,33 @@ func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, er for i := 0; i < maxRetryAttempts; i++ { resp, err = action() - if err != nil { - - sleepDuration := time.Duration(i*jitterMultiplier) * time.Second + if err != nil || (resp != nil && resp.StatusCode >= 500) { + if resp != nil { + resp.Body.Close() + } + err = fmt.Errorf("attempt %d: received status code %d", i+1, getStatusCode(resp)) + // Apply jittered exponential back-off before retrying. + jitter := time.Duration(rand.Float64() * float64(jitterMultiplier) * float64(time.Second)) + sleepDuration := time.Duration(i*i)*time.Second + jitter time.Sleep(sleepDuration) - continue - } - - if resp.StatusCode >= 500 { - resp.Body.Close() - sleepDuration := time.Duration(i*jitterMultiplier) * time.Second - time.Sleep(sleepDuration) continue } - - // Success or client error, do not retry + // Success (2xx) or client-side error (4xx), do not retry. break } - if resp != nil && resp.StatusCode >= 500 { - resp.Body.Close() - return nil, fmt.Errorf("received %d status code after %d attempts", resp.StatusCode, maxRetryAttempts) + return resp, err +} + +// getStatusCode safely retrieves the status code from the response. +// Returns 0 if resp is nil. +func getStatusCode(resp *http.Response) int { + if resp == nil { + return 0 } - return resp, err + return resp.StatusCode } func MakeDetermineVersionRequest(name string, hashes []DetermineVersionHash) (*DetermineVersionResponse, error) { diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go index b1ad1021ef..84a5d7926e 100644 --- a/pkg/osv/osv_test.go +++ b/pkg/osv/osv_test.go @@ -1,50 +1,99 @@ +// osv_test.go + package osv import ( - "log" "net/http" "net/http/httptest" "testing" "time" ) -func TestRetryOn5xx(t *testing.T) { - attempt := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempt++ - w.WriteHeader(http.StatusInternalServerError) // 500 - })) - defer server.Close() - - // Override the QueryEndpoint for testing - originalQueryEndpoint := QueryEndpoint - QueryEndpoint = server.URL - defer func() { QueryEndpoint = originalQueryEndpoint }() - - client := &http.Client{ - Timeout: 2 * time.Second, +func TestMakeRetryRequest(t *testing.T) { + testCases := []struct { + name string + statusCodes []int + expectedRespNil bool + expectedErr bool + expectedAttempts int + }{ + { + name: "Success on first attempt (200)", + statusCodes: []int{200}, + expectedRespNil: false, + expectedErr: false, + expectedAttempts: 1, + }, + { + name: "Client error (400), no retry", + statusCodes: []int{400}, + expectedRespNil: false, + expectedErr: false, + expectedAttempts: 1, + }, + { + name: "Server error (500) x4, fail after retries", + statusCodes: []int{500, 500, 500, 500}, + expectedRespNil: false, // resp is returned but contains server error + expectedErr: true, + expectedAttempts: maxRetryAttempts, + }, + { + name: "Server error (500) x2, then success (200)", + statusCodes: []int{500, 500, 200}, + expectedRespNil: false, + expectedErr: false, + expectedAttempts: 3, + }, } - resp, err := makeRetryRequest(func() (*http.Response, error) { - req, _ := http.NewRequest(http.MethodPost, QueryEndpoint, nil) - req.Header.Set("Content-Type", "application/json") - return client.Do(req) - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + attempt := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if attempt < len(tc.statusCodes) { + w.WriteHeader(tc.statusCodes[attempt]) + } else { + // If more requests are made than status codes provided, repeat the last status code. + w.WriteHeader(tc.statusCodes[len(tc.statusCodes)-1]) + } + attempt++ + })) + defer server.Close() - log.Printf("TestRetryOn5xx: resp = %v, err = %v", resp, err) + // Override the QueryEndpoint for testing. + originalQueryEndpoint := QueryEndpoint + QueryEndpoint = server.URL + defer func() { QueryEndpoint = originalQueryEndpoint }() - // Assertion: resp should be nil - if resp != nil { - t.Errorf("Expected response to be nil after retries on 5xx errors, but got: %v", resp) - } + client := &http.Client{ + Timeout: 2 * time.Second, + } - // Assertion: err should not be nil - if err == nil { - t.Errorf("Expected an error after retries on 5xx errors, but got none") - } + resp, err := makeRetryRequest(func() (*http.Response, error) { + req, _ := http.NewRequest(http.MethodPost, QueryEndpoint, nil) + req.Header.Set("Content-Type", "application/json") + return client.Do(req) + }) + + // Assertions using standard library. + if tc.expectedRespNil && resp != nil { + t.Errorf("Expected response to be nil, but got: %v", resp) + } + if !tc.expectedRespNil && resp == nil { + t.Errorf("Expected response to be non-nil, but got nil") + } + + if tc.expectedErr && err == nil { + t.Errorf("Expected an error, but got none") + } + if !tc.expectedErr && err != nil { + t.Errorf("Did not expect an error, but got: %v", err) + } - // Assertion: number of attempts should equal maxRetryAttempts - if attempt != maxRetryAttempts { - t.Errorf("Expected number of attempts to equal maxRetryAttempts (%d), but got: %d", maxRetryAttempts, attempt) + if attempt != tc.expectedAttempts { + t.Errorf("Expected %d attempts, but got: %d", tc.expectedAttempts, attempt) + } + }) } } From e297913c31f0699ed8ce88ff50b153d30c050a99 Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Sun, 10 Nov 2024 16:06:07 -0600 Subject: [PATCH 5/8] Removing the QueryEndpoint Override from test --- pkg/osv/osv_test.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go index 84a5d7926e..5ec6228229 100644 --- a/pkg/osv/osv_test.go +++ b/pkg/osv/osv_test.go @@ -10,6 +10,8 @@ import ( ) func TestMakeRetryRequest(t *testing.T) { + t.Parallel() // Enable parallel execution of the main test function + testCases := []struct { name string statusCodes []int @@ -48,9 +50,11 @@ func TestMakeRetryRequest(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture range variable for parallel tests t.Run(tc.name, func(t *testing.T) { + t.Parallel() // Enable parallel execution of subtests attempt := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if attempt < len(tc.statusCodes) { w.WriteHeader(tc.statusCodes[attempt]) } else { @@ -61,22 +65,23 @@ func TestMakeRetryRequest(t *testing.T) { })) defer server.Close() - // Override the QueryEndpoint for testing. - originalQueryEndpoint := QueryEndpoint - QueryEndpoint = server.URL - defer func() { QueryEndpoint = originalQueryEndpoint }() - client := &http.Client{ Timeout: 2 * time.Second, } resp, err := makeRetryRequest(func() (*http.Response, error) { - req, _ := http.NewRequest(http.MethodPost, QueryEndpoint, nil) + req, err := http.NewRequest(http.MethodPost, server.URL, nil) + if err != nil { + return nil, err + } req.Header.Set("Content-Type", "application/json") return client.Do(req) }) + if resp != nil { + defer resp.Body.Close() + } - // Assertions using standard library. + // Assertions if tc.expectedRespNil && resp != nil { t.Errorf("Expected response to be nil, but got: %v", resp) } From 8c369d6555b7ce29ae926281bcd2c0b4ded3fd2c Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Mon, 11 Nov 2024 14:25:01 -0600 Subject: [PATCH 6/8] Added error handling in makeRetryRequest, removed redundant functions, and cleared all lint warnings --- pkg/osv/osv.go | 79 ++++++++++++++------------- pkg/osv/osv_test.go | 129 ++++++++++++++++++++++---------------------- 2 files changed, 105 insertions(+), 103 deletions(-) diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index a77ff9f6c5..c37a34d574 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -156,21 +156,6 @@ func chunkBy[T any](items []T, chunkSize int) [][]T { return append(chunks, items) } -// checkResponseError checks if the response has an error. -func checkResponseError(resp *http.Response) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return nil - } - - respBuf, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read error response from server: %w", err) - } - defer resp.Body.Close() - - return fmt.Errorf("server response error: %s", string(respBuf)) -} - // MakeRequest sends a batched query to osv.dev func MakeRequest(request BatchedQuery) (*BatchedResponse, error) { return MakeRequestWithClient(request, http.DefaultClient) @@ -306,40 +291,54 @@ func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBat return &hydrated, nil } -// makeRetryRequest will return an error on both network errors, and if the response is not 2xx +// makeRetryRequest executes HTTP requests with exponential backoff retry logic func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) { - var resp *http.Response - var err error - - for i := 0; i < maxRetryAttempts; i++ { - resp, err = action() - if err != nil || (resp != nil && resp.StatusCode >= 500) { - if resp != nil { - resp.Body.Close() + const maxRetryAttempts = 4 + const jitterMultiplier = 2 + + var lastErr error + + for i := range maxRetryAttempts { + resp, err := action() + if err != nil { + lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err) + if i == maxRetryAttempts-1 { + break } - err = fmt.Errorf("attempt %d: received status code %d", i+1, getStatusCode(resp)) - // Apply jittered exponential back-off before retrying. + backoff := time.Duration(i*i) * time.Second jitter := time.Duration(rand.Float64() * float64(jitterMultiplier) * float64(time.Second)) - sleepDuration := time.Duration(i*i)*time.Second + jitter - time.Sleep(sleepDuration) - + time.Sleep(backoff + jitter) continue } - // Success (2xx) or client-side error (4xx), do not retry. - break - } - return resp, err -} + // Check response validity + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return resp, nil + } + + // Read error response + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err) + } else if resp.StatusCode >= 400 && resp.StatusCode < 500 { + // Client errors (4xx) - return immediately + return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body) + } else { + // Server errors (5xx) - can retry + lastErr = fmt.Errorf("server error: status=%d body=%s", resp.StatusCode, body) + } + + if i == maxRetryAttempts-1 { + break + } -// getStatusCode safely retrieves the status code from the response. -// Returns 0 if resp is nil. -func getStatusCode(resp *http.Response) int { - if resp == nil { - return 0 + backoff := time.Duration(i*i) * time.Second + jitter := time.Duration(rand.Float64() * float64(jitterMultiplier) * float64(time.Second)) + time.Sleep(backoff + jitter) } - return resp.StatusCode + return nil, fmt.Errorf("max retries exceeded: %w", lastErr) } func MakeDetermineVersionRequest(name string, hashes []DetermineVersionHash) (*DetermineVersionResponse, error) { diff --git a/pkg/osv/osv_test.go b/pkg/osv/osv_test.go index 5ec6228229..68d8e51688 100644 --- a/pkg/osv/osv_test.go +++ b/pkg/osv/osv_test.go @@ -1,103 +1,106 @@ -// osv_test.go - package osv import ( + "fmt" + "io" "net/http" "net/http/httptest" + "strings" "testing" "time" ) func TestMakeRetryRequest(t *testing.T) { - t.Parallel() // Enable parallel execution of the main test function + t.Parallel() - testCases := []struct { - name string - statusCodes []int - expectedRespNil bool - expectedErr bool - expectedAttempts int + tests := []struct { + name string + statusCodes []int + expectedError string + wantAttempts int }{ { - name: "Success on first attempt (200)", - statusCodes: []int{200}, - expectedRespNil: false, - expectedErr: false, - expectedAttempts: 1, + name: "success on first attempt", + statusCodes: []int{http.StatusOK}, + wantAttempts: 1, }, { - name: "Client error (400), no retry", - statusCodes: []int{400}, - expectedRespNil: false, - expectedErr: false, - expectedAttempts: 1, + name: "client error no retry", + statusCodes: []int{http.StatusBadRequest}, + expectedError: "client error: status=400", + wantAttempts: 1, }, { - name: "Server error (500) x4, fail after retries", - statusCodes: []int{500, 500, 500, 500}, - expectedRespNil: false, // resp is returned but contains server error - expectedErr: true, - expectedAttempts: maxRetryAttempts, + name: "server error then success", + statusCodes: []int{http.StatusInternalServerError, http.StatusOK}, + wantAttempts: 2, }, { - name: "Server error (500) x2, then success (200)", - statusCodes: []int{500, 500, 200}, - expectedRespNil: false, - expectedErr: false, - expectedAttempts: 3, + name: "max retries on server error", + statusCodes: []int{http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError, http.StatusInternalServerError}, + expectedError: "max retries exceeded", + wantAttempts: 4, }, } - for _, tc := range testCases { - tc := tc // Capture range variable for parallel tests - t.Run(tc.name, func(t *testing.T) { - t.Parallel() // Enable parallel execution of subtests - attempt := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - if attempt < len(tc.statusCodes) { - w.WriteHeader(tc.statusCodes[attempt]) - } else { - // If more requests are made than status codes provided, repeat the last status code. - w.WriteHeader(tc.statusCodes[len(tc.statusCodes)-1]) + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + attempts := 0 + idx := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + status := tt.statusCodes[idx] + if idx < len(tt.statusCodes)-1 { + idx++ } - attempt++ + + w.WriteHeader(status) + message := fmt.Sprintf("response-%d", attempts) + w.Write([]byte(message)) })) defer server.Close() - client := &http.Client{ - Timeout: 2 * time.Second, - } + client := &http.Client{Timeout: time.Second} resp, err := makeRetryRequest(func() (*http.Response, error) { - req, err := http.NewRequest(http.MethodPost, server.URL, nil) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - return client.Do(req) + return client.Get(server.URL) }) - if resp != nil { - defer resp.Body.Close() + + if attempts != tt.wantAttempts { + t.Errorf("got %d attempts, want %d", attempts, tt.wantAttempts) } - // Assertions - if tc.expectedRespNil && resp != nil { - t.Errorf("Expected response to be nil, but got: %v", resp) + if tt.expectedError != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.expectedError) + } + if !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %q", tt.expectedError, err) + } + return } - if !tc.expectedRespNil && resp == nil { - t.Errorf("Expected response to be non-nil, but got nil") + + if err != nil { + t.Fatalf("unexpected error: %v", err) } - if tc.expectedErr && err == nil { - t.Errorf("Expected an error, but got none") + if resp == nil { + t.Fatal("expected non-nil response") } - if !tc.expectedErr && err != nil { - t.Errorf("Did not expect an error, but got: %v", err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) } - if attempt != tc.expectedAttempts { - t.Errorf("Expected %d attempts, but got: %d", tc.expectedAttempts, attempt) + expectedBody := fmt.Sprintf("response-%d", attempts) + if string(body) != expectedBody { + t.Errorf("got body %q, want %q", string(body), expectedBody) } }) } From e81d59d7aed6d7f945ea630ba150dff3f70a5b32 Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Mon, 11 Nov 2024 16:09:54 -0600 Subject: [PATCH 7/8] removed unused dependency in go.mod, reverted constants to original format, and restored jitter implementation to avoid code duplication in makeRetryRequest function --- go.mod | 1 - pkg/osv/osv.go | 34 +++++++++++----------------------- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index c5896ae86a..34e2cc6302 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,6 @@ require ( github.com/package-url/packageurl-go v0.1.3 github.com/pandatix/go-cvss v0.6.2 github.com/spdx/tools-golang v0.5.5 - github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/pretty v1.2.1 github.com/tidwall/sjson v1.2.5 diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index c37a34d574..cd68c0b33a 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sync/errgroup" ) -var ( +const ( // QueryEndpoint is the URL for posting queries to OSV. QueryEndpoint = "https://api.osv.dev/v1/querybatch" // GetEndpoint is the URL for getting vulnerabilities from OSV. @@ -293,49 +293,37 @@ func HydrateWithClient(resp *BatchedResponse, client *http.Client) (*HydratedBat // makeRetryRequest executes HTTP requests with exponential backoff retry logic func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, error) { - const maxRetryAttempts = 4 - const jitterMultiplier = 2 - var lastErr error for i := range maxRetryAttempts { + // rand is initialized with a random number (since go1.20), and is also safe to use concurrently + // we do not need to use a cryptographically secure random jitter, this is just to spread out the retry requests + // #nosec G404 + jitterAmount := (rand.Float64() * float64(jitterMultiplier) * float64(i)) + time.Sleep(time.Duration(i*i)*time.Second + time.Duration(jitterAmount*1000)*time.Millisecond) + resp, err := action() if err != nil { lastErr = fmt.Errorf("attempt %d: request failed: %w", i+1, err) - if i == maxRetryAttempts-1 { - break - } - backoff := time.Duration(i*i) * time.Second - jitter := time.Duration(rand.Float64() * float64(jitterMultiplier) * float64(time.Second)) - time.Sleep(backoff + jitter) continue } - // Check response validity if resp.StatusCode >= 200 && resp.StatusCode < 300 { return resp, nil } - // Read error response body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { lastErr = fmt.Errorf("attempt %d: failed to read response: %w", i+1, err) - } else if resp.StatusCode >= 400 && resp.StatusCode < 500 { - // Client errors (4xx) - return immediately - return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body) - } else { - // Server errors (5xx) - can retry - lastErr = fmt.Errorf("server error: status=%d body=%s", resp.StatusCode, body) + continue } - if i == maxRetryAttempts-1 { - break + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body) } - backoff := time.Duration(i*i) * time.Second - jitter := time.Duration(rand.Float64() * float64(jitterMultiplier) * float64(time.Second)) - time.Sleep(backoff + jitter) + lastErr = fmt.Errorf("server error: status=%d body=%s", resp.StatusCode, body) } return nil, fmt.Errorf("max retries exceeded: %w", lastErr) From a1a3e9363b171fae9dad927a6ec0f83e631d482b Mon Sep 17 00:00:00 2001 From: Vishal Gawade Date: Fri, 15 Nov 2024 00:19:18 -0600 Subject: [PATCH 8/8] Adding error handling for status code 429 --- pkg/osv/osv.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index cd68c0b33a..0faf7da1a6 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -319,6 +319,11 @@ func makeRetryRequest(action func() (*http.Response, error)) (*http.Response, er continue } + if resp.StatusCode == 429 { + lastErr = fmt.Errorf("attempt %d: too many requests: status=%d body=%s", i+1, resp.StatusCode, body) + continue + } + if resp.StatusCode >= 400 && resp.StatusCode < 500 { return nil, fmt.Errorf("client error: status=%d body=%s", resp.StatusCode, body) }