-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This major version release utilizes the latest version of the aws-sdk-go-v2. The following behavioral changes are included in this major version release: - Custom endpoint resolvers are attached to the STS and IAM clients, not to the credentials. This is apart of the aws-sdk-go-v2 EndpointResolverV2 feature. - By default, aws credential configurations will load values from environment variables. The user provided options will overload the default values. - The ability to mock out the underlying credential provider for unit testing. Changed behaviors from awsutil v1 includes the following: - Replaced aws errors with aws smithy-go errors - No longer able to utilize the aws default remote credential provider - The function GenerateCredentialChain returns a aws.Config, which contains the credential provider.
- Loading branch information
Showing
17 changed files
with
3,563 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
// Copyright (c) HashiCorp, Inc. | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
package awsutil | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/url" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/iam" | ||
"github.com/aws/aws-sdk-go-v2/service/sts" | ||
smithyEndpoints "github.com/aws/smithy-go/endpoints" | ||
) | ||
|
||
// IAMAPIFunc is a factory function for returning an IAM interface, | ||
// useful for supplying mock interfaces for testing IAM. | ||
type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error) | ||
|
||
// IAMClient represents an iam.Client | ||
type IAMClient interface { | ||
CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) | ||
DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) | ||
ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error) | ||
GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error) | ||
} | ||
|
||
// STSAPIFunc is a factory function for returning a STS interface, | ||
// useful for supplying mock interfaces for testing STS. | ||
type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error) | ||
|
||
// STSClient represents an sts.Client | ||
type STSClient interface { | ||
AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) | ||
GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) | ||
} | ||
|
||
// IAMClient returns an IAM client. | ||
// | ||
// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpoint. | ||
// | ||
// If WithIAMAPIFunc is supplied, the included function is used as | ||
// the IAM client constructor instead. This can be used for Mocking | ||
// the IAM API. | ||
func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) { | ||
opts, err := getOpts(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error reading options: %w", err) | ||
} | ||
|
||
cfg := opts.withAwsConfig | ||
if cfg == nil { | ||
cfg, err = c.GenerateCredentialChain(ctx, opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) | ||
} | ||
} | ||
|
||
if opts.withIAMAPIFunc != nil { | ||
return opts.withIAMAPIFunc(cfg) | ||
} | ||
|
||
var iamOpts []func(*iam.Options) | ||
if c.IAMEndpoint != "" { | ||
iamOpts = append(iamOpts, iam.WithEndpointResolverV2(&iamEndpointResolver{ | ||
endpoint: c.IAMEndpoint, | ||
})) | ||
} | ||
|
||
return iam.NewFromConfig(*cfg, iamOpts...), nil | ||
} | ||
|
||
// STSClient returns a STS client. | ||
// | ||
// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpoint. | ||
// | ||
// If WithSTSAPIFunc is supplied, the included function is used as | ||
// the STS client constructor instead. This can be used for Mocking | ||
// the STS API. | ||
func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) { | ||
opts, err := getOpts(opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error reading options: %w", err) | ||
} | ||
|
||
cfg := opts.withAwsConfig | ||
if cfg == nil { | ||
cfg, err = c.GenerateCredentialChain(ctx, opt...) | ||
if err != nil { | ||
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) | ||
} | ||
} | ||
|
||
if opts.withSTSAPIFunc != nil { | ||
return opts.withSTSAPIFunc(cfg) | ||
} | ||
|
||
var stsOpts []func(*sts.Options) | ||
if c.STSEndpoint != "" { | ||
stsOpts = append(stsOpts, sts.WithEndpointResolverV2(&stsEndpointResolver{ | ||
endpoint: c.STSEndpoint, | ||
})) | ||
} | ||
|
||
return sts.NewFromConfig(*cfg, stsOpts...), nil | ||
} | ||
|
||
// iamEndpointResolver is a implementation of the iam.EndpointResolverV2 interface. | ||
// The iamEndpointResolver is used with the IAMClient when the CredentialConfig | ||
// is configured with a custom endpoint for iam. | ||
type iamEndpointResolver struct { | ||
endpoint string | ||
} | ||
|
||
// ResolveEndpoint returns a smithyEndpoint resolver for the iam endpoint given | ||
// to the CredentialConfig. The resolver is set to the default endpoint resovler | ||
// provided by aws when the custom iam endpoint is empty or is not a parsable url. | ||
func (e *iamEndpointResolver) ResolveEndpoint(ctx context.Context, params iam.EndpointParameters) (resolver smithyEndpoints.Endpoint, err error) { | ||
var uri *url.URL | ||
resolver, err = iam.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) | ||
if err != nil { | ||
return | ||
} | ||
if e.endpoint == "" { | ||
// TODO: not sure if we should set an error or ignore the missing endpoint value and continue with the default resolver | ||
return | ||
} | ||
uri, err = url.Parse(e.endpoint) | ||
if err != nil { | ||
return | ||
} | ||
if uri == nil { | ||
// TODO: not sure if we should set an error or ignore the missing uri value and continue with the default resolver | ||
return | ||
} | ||
resolver = smithyEndpoints.Endpoint{ | ||
URI: *uri, | ||
} | ||
return | ||
} | ||
|
||
// stsEndpointResolver is a implementation of the sts.EndpointResolverV2 interface. | ||
// The stsEndpointResolver is used with the STSClient when the CredentialConfig | ||
// is configured with a custom endpoint for sts. | ||
type stsEndpointResolver struct { | ||
endpoint string | ||
} | ||
|
||
// ResolveEndpoint returns a smithyEndpoint resolver for the sts endpoint given | ||
// to the CredentialConfig. The resolver is set to the default endpoint resovler | ||
// provided by aws when the custom sts endpoint is empty or is not a parsable url. | ||
func (e *stsEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (resolver smithyEndpoints.Endpoint, err error) { | ||
var uri *url.URL | ||
resolver, err = sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) | ||
if err != nil { | ||
return | ||
} | ||
if e.endpoint == "" { | ||
// TODO: not sure if we should set an error or ignore the missing endpoint value and continue with the default resolver | ||
return | ||
} | ||
uri, err = url.Parse(e.endpoint) | ||
if err != nil { | ||
return | ||
} | ||
if uri == nil { | ||
// TODO: not sure if we should set an error or ignore the missing uri value and continue with the default resolver | ||
return | ||
} | ||
resolver = smithyEndpoints.Endpoint{ | ||
URI: *uri, | ||
} | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
// Copyright (c) HashiCorp, Inc. | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
package awsutil | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go-v2/service/iam" | ||
"github.com/aws/aws-sdk-go-v2/service/sts" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const testOptionErr = "test option error" | ||
|
||
func TestCredentialsConfigIAMClient(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
credentialsConfig *CredentialsConfig | ||
opts []Option | ||
require func(t *testing.T, actual IAMClient) | ||
requireErr string | ||
}{ | ||
{ | ||
name: "options error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{MockOptionErr(errors.New(testOptionErr))}, | ||
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), | ||
}, | ||
{ | ||
name: "with mock IAM session", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{WithIAMAPIFunc(NewMockIAM())}, | ||
require: func(t *testing.T, actual IAMClient) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.Equal(&MockIAM{}, actual) | ||
}, | ||
}, | ||
{ | ||
name: "no mock client", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{}, | ||
require: func(t *testing.T, actual IAMClient) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.IsType(&iam.Client{}, actual) | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
require := require.New(t) | ||
actual, err := tc.credentialsConfig.IAMClient(context.TODO(), tc.opts...) | ||
if tc.requireErr != "" { | ||
require.EqualError(err, tc.requireErr) | ||
return | ||
} | ||
|
||
require.NoError(err) | ||
tc.require(t, actual) | ||
}) | ||
} | ||
} | ||
|
||
func TestCredentialsConfigSTSClient(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
credentialsConfig *CredentialsConfig | ||
opts []Option | ||
require func(t *testing.T, actual STSClient) | ||
requireErr string | ||
}{ | ||
{ | ||
name: "options error", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{MockOptionErr(errors.New(testOptionErr))}, | ||
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), | ||
}, | ||
{ | ||
name: "with mock STS session", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{WithSTSAPIFunc(NewMockSTS())}, | ||
require: func(t *testing.T, actual STSClient) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.Equal(&MockSTS{}, actual) | ||
}, | ||
}, | ||
{ | ||
name: "no mock client", | ||
credentialsConfig: &CredentialsConfig{}, | ||
opts: []Option{}, | ||
require: func(t *testing.T, actual STSClient) { | ||
t.Helper() | ||
require := require.New(t) | ||
require.IsType(&sts.Client{}, actual) | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range cases { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
require := require.New(t) | ||
actual, err := tc.credentialsConfig.STSClient(context.TODO(), tc.opts...) | ||
if tc.requireErr != "" { | ||
require.EqualError(err, tc.requireErr) | ||
return | ||
} | ||
|
||
require.NoError(err) | ||
tc.require(t, actual) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) HashiCorp, Inc. | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
package awsutil | ||
|
||
import ( | ||
"errors" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws/retry" | ||
multierror "github.com/hashicorp/go-multierror" | ||
) | ||
|
||
var ErrUpstreamRateLimited = errors.New("upstream rate limited") | ||
|
||
// CheckAWSError will examine an error and convert to a logical error if | ||
// appropriate. If no appropriate error is found, return nil | ||
func CheckAWSError(err error) error { | ||
retryErr := retry.ThrottleErrorCode{ | ||
Codes: retry.DefaultThrottleErrorCodes, | ||
} | ||
if retryErr.IsErrorThrottle(err).Bool() { | ||
return ErrUpstreamRateLimited | ||
} | ||
return nil | ||
} | ||
|
||
// AppendAWSError checks if the given error is a known AWS error we modify, | ||
// and if so then returns a go-multierror, appending the original and the | ||
// AWS error. | ||
// If the error is not an AWS error, or not an error we wish to modify, then | ||
// return the original error. | ||
func AppendAWSError(err error) error { | ||
if awserr := CheckAWSError(err); awserr != nil { | ||
err = multierror.Append(err, awserr) | ||
} | ||
return err | ||
} |
Oops, something went wrong.