-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adds azure metadata provider to cloudmeta pkg. (#4159)
feat: adds azure metadata provider to cloudmeta pkg. This adds azure metadata provider to cloudmeta pkg. Plus adds retries to azure metadata service using the github.com/hashicorp/go-retryablehttp package. Also adds some unit tests. Fixes: #4129
- Loading branch information
1 parent
0372277
commit 2fde888
Showing
31 changed files
with
2,313 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// Copyright (C) 2024 ScyllaDB | ||
|
||
package cloudmeta | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"net/http" | ||
"time" | ||
|
||
"github.com/hashicorp/go-retryablehttp" | ||
"github.com/pkg/errors" | ||
"github.com/scylladb/go-log" | ||
) | ||
|
||
// azureBaseURL is a base url of azure metadata service. | ||
const azureBaseURL = "http://169.254.169.254/metadata" | ||
|
||
// azureMetadata is a wrapper around azure metadata service. | ||
type azureMetadata struct { | ||
client *http.Client | ||
|
||
baseURL string | ||
} | ||
|
||
// newAzureMetadata returns AzureMetadata service. | ||
func newAzureMetadata(logger log.Logger) *azureMetadata { | ||
return &azureMetadata{ | ||
client: defaultClient(logger), | ||
baseURL: azureBaseURL, | ||
} | ||
} | ||
|
||
func defaultClient(logger log.Logger) *http.Client { | ||
client := retryablehttp.NewClient() | ||
|
||
client.RetryMax = 3 | ||
client.RetryWaitMin = 500 * time.Millisecond | ||
client.RetryWaitMax = 5 * time.Second | ||
client.Logger = &logWrapper{ | ||
logger: logger, | ||
} | ||
|
||
transport := http.DefaultTransport.(*http.Transport).Clone() | ||
// we must not use proxy for the metadata requests - see https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#proxies. | ||
transport.Proxy = nil | ||
|
||
client.HTTPClient = &http.Client{ | ||
// Quite small timeout per request, because we have retries and also it's a local network call. | ||
Timeout: 1 * time.Second, | ||
Transport: transport, | ||
} | ||
return client.StandardClient() | ||
} | ||
|
||
// Metadata return InstanceMetadata from azure if available. | ||
func (azure *azureMetadata) Metadata(ctx context.Context) (InstanceMetadata, error) { | ||
vmSize, err := azure.getVMSize(ctx) | ||
if err != nil { | ||
return InstanceMetadata{}, errors.Wrap(err, "azure.getVMSize") | ||
} | ||
if vmSize == "" { | ||
return InstanceMetadata{}, errors.New("azure vmSize is empty") | ||
} | ||
return InstanceMetadata{ | ||
CloudProvider: CloudProviderAzure, | ||
InstanceType: vmSize, | ||
}, nil | ||
} | ||
|
||
// azureAPIVersion should be present in every request to metadata service in query parameter. | ||
const azureAPIVersion = "2023-07-01" | ||
|
||
func (azure *azureMetadata) getVMSize(ctx context.Context) (string, error) { | ||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, azure.baseURL+"/instance", http.NoBody) | ||
if err != nil { | ||
return "", errors.Wrap(err, "http new request") | ||
} | ||
|
||
// Setting required headers and query parameters - https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#security-and-authentication. | ||
req.Header.Add("Metadata", "true") | ||
query := req.URL.Query() | ||
query.Add("api-version", azureAPIVersion) | ||
req.URL.RawQuery = query.Encode() | ||
|
||
resp, err := azure.client.Do(req) | ||
if err != nil { | ||
return "", errors.Wrap(err, "azure.client.Do") | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return "", errors.Errorf("status code (%d) != 200", resp.StatusCode) | ||
} | ||
|
||
var data azureMetadataResponse | ||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { | ||
return "", errors.Wrap(err, "decode json") | ||
} | ||
|
||
return data.Compute.VMSize, nil | ||
} | ||
|
||
// azureMetadataResponse represents azure metadata service response. | ||
// full response specification can be found here - https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#response-1. | ||
type azureMetadataResponse struct { | ||
Compute azureCompute `json:"compute"` | ||
} | ||
|
||
type azureCompute struct { | ||
VMSize string `json:"vmSize"` | ||
} | ||
|
||
// logWrapper implements go-retryablehttp.LeveledLogger interface. | ||
type logWrapper struct { | ||
logger log.Logger | ||
} | ||
|
||
// Info wraps logger.Info method. | ||
func (log *logWrapper) Info(msg string, keyVals ...interface{}) { | ||
log.logger.Info(context.Background(), msg, keyVals...) | ||
} | ||
|
||
// Error wraps logger.Error method. | ||
func (log *logWrapper) Error(msg string, keyVals ...interface{}) { | ||
log.logger.Error(context.Background(), msg, keyVals...) | ||
} | ||
|
||
// Warn wraps logger.Error method. | ||
func (log *logWrapper) Warn(msg string, keyVals ...interface{}) { | ||
log.logger.Error(context.Background(), msg, keyVals...) | ||
} | ||
|
||
// Debug wraps logger.Debug method. | ||
func (log *logWrapper) Debug(msg string, keyVals ...interface{}) { | ||
log.logger.Debug(context.Background(), msg, keyVals...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Copyright (C) 2024 ScyllaDB | ||
|
||
package cloudmeta | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/scylladb/go-log" | ||
) | ||
|
||
func TestAzureMetadata(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
handler http.Handler | ||
|
||
expectedCalls int | ||
expectedErr bool | ||
expectedMeta InstanceMetadata | ||
}{ | ||
{ | ||
name: "when response is 200", | ||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
testCheckRequireParams(t, r) | ||
|
||
w.Write([]byte(`{"compute":{"vmSize":"Standard-A3"}}`)) | ||
}), | ||
expectedCalls: 1, | ||
expectedErr: false, | ||
expectedMeta: InstanceMetadata{ | ||
CloudProvider: CloudProviderAzure, | ||
InstanceType: "Standard-A3", | ||
}, | ||
}, | ||
{ | ||
name: "when response is 404: not retryable", | ||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
testCheckRequireParams(t, r) | ||
|
||
w.WriteHeader(http.StatusNotFound) | ||
w.Write([]byte(`internal server error`)) | ||
}), | ||
expectedCalls: 1, | ||
expectedErr: true, | ||
expectedMeta: InstanceMetadata{}, | ||
}, | ||
{ | ||
name: "when response is 500: retryable", | ||
handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
testCheckRequireParams(t, r) | ||
|
||
w.WriteHeader(http.StatusInternalServerError) | ||
w.Write([]byte(`internal server error`)) | ||
}), | ||
expectedCalls: 4, | ||
expectedErr: true, | ||
expectedMeta: InstanceMetadata{}, | ||
}, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
handler := &testHandler{Handler: tc.handler} | ||
testSrv := httptest.NewServer(handler) | ||
defer testSrv.Close() | ||
|
||
azureMeta := newAzureMetadata(log.NewDevelopment()) | ||
azureMeta.baseURL = testSrv.URL | ||
|
||
meta, err := azureMeta.Metadata(context.Background()) | ||
if tc.expectedErr && err == nil { | ||
t.Fatalf("expected err: %v\n", err) | ||
} | ||
if !tc.expectedErr && err != nil { | ||
t.Fatalf("unexpected err: %v\n", err) | ||
} | ||
|
||
if tc.expectedCalls != handler.calls { | ||
t.Fatalf("unexected number of calls: %d != %d", handler.calls, tc.expectedCalls) | ||
} | ||
|
||
if meta.CloudProvider != tc.expectedMeta.CloudProvider { | ||
t.Fatalf("unexpected cloud provider: %s", meta.CloudProvider) | ||
} | ||
|
||
if meta.InstanceType != tc.expectedMeta.InstanceType { | ||
t.Fatalf("unexpected instance type: %s", meta.InstanceType) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
type testHandler struct { | ||
http.Handler | ||
// Keep track of how many times handler func has been called | ||
// so we can test retries policy. | ||
calls int | ||
} | ||
|
||
func (th *testHandler) ServeHTTP(w http.ResponseWriter, t *http.Request) { | ||
th.calls++ | ||
th.Handler.ServeHTTP(w, t) | ||
} | ||
|
||
func testCheckRequireParams(t *testing.T, r *http.Request) { | ||
t.Helper() | ||
metadataHeader := r.Header.Get("Metadata") | ||
if metadataHeader != "true" { | ||
t.Fatalf("Metadata: true header is required") | ||
} | ||
apiVersion := r.URL.Query().Get("api-version") | ||
if apiVersion != azureAPIVersion { | ||
t.Fatalf("unexpected ?api-version: %s", apiVersion) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.