Skip to content

Commit

Permalink
feat(awsutil): major version update
Browse files Browse the repository at this point in the history
This major version release utilizes the latest version
of the aws-sdk-go-v2. The following behavioral changes
are included in this major version release:

- Custom endpoint resolvers are attached to the STS and
IAM clients, not to the credentials. This is apart of the
aws-sdk-go-v2 EndpointResolverV2 feature.
- By default, aws credential configurations will load values
from environment variables. The user provided options will
overload the default values.
- The ability to mock out the underlying credential provider
for unit testing.

Changed behaviors from awsutil v1 includes the following:

- Replaced aws errors with aws smithy-go errors
- No longer able to utilize the aws default remote credential
provider
- The function GenerateCredentialChain returns a aws.Config,
which contains the credential provider.
  • Loading branch information
ddebko committed Aug 25, 2023
1 parent 72fcd87 commit 1d97ab4
Show file tree
Hide file tree
Showing 17 changed files with 3,563 additions and 0 deletions.
365 changes: 365 additions & 0 deletions awsutil/v2/LICENSE

Large diffs are not rendered by default.

175 changes: 175 additions & 0 deletions awsutil/v2/clients.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsutil

import (
"context"
"fmt"
"net/url"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyEndpoints "github.com/aws/smithy-go/endpoints"
)

// IAMAPIFunc is a factory function for returning an IAM interface,
// useful for supplying mock interfaces for testing IAM.
type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error)

// IAMClient represents an iam.Client
type IAMClient interface {
CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error)
DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error)
ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error)
GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error)
}

// STSAPIFunc is a factory function for returning a STS interface,
// useful for supplying mock interfaces for testing STS.
type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error)

// STSClient represents an sts.Client
type STSClient interface {
AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error)
GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}

// IAMClient returns an IAM client.
//
// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpoint.
//
// If WithIAMAPIFunc is supplied, the included function is used as
// the IAM client constructor instead. This can be used for Mocking
// the IAM API.
func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withIAMAPIFunc != nil {
return opts.withIAMAPIFunc(cfg)
}

var iamOpts []func(*iam.Options)
if c.IAMEndpoint != "" {
iamOpts = append(iamOpts, iam.WithEndpointResolverV2(&iamEndpointResolver{
endpoint: c.IAMEndpoint,
}))
}

return iam.NewFromConfig(*cfg, iamOpts...), nil
}

// STSClient returns a STS client.
//
// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpoint.
//
// If WithSTSAPIFunc is supplied, the included function is used as
// the STS client constructor instead. This can be used for Mocking
// the STS API.
func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withSTSAPIFunc != nil {
return opts.withSTSAPIFunc(cfg)
}

var stsOpts []func(*sts.Options)
if c.STSEndpoint != "" {
stsOpts = append(stsOpts, sts.WithEndpointResolverV2(&stsEndpointResolver{
endpoint: c.STSEndpoint,
}))
}

return sts.NewFromConfig(*cfg, stsOpts...), nil
}

// iamEndpointResolver is a implementation of the iam.EndpointResolverV2 interface.
// The iamEndpointResolver is used with the IAMClient when the CredentialConfig
// is configured with a custom endpoint for iam.
type iamEndpointResolver struct {
endpoint string
}

// ResolveEndpoint returns a smithyEndpoint resolver for the iam endpoint given
// to the CredentialConfig. The resolver is set to the default endpoint resovler
// provided by aws when the custom iam endpoint is empty or is not a parsable url.
func (e *iamEndpointResolver) ResolveEndpoint(ctx context.Context, params iam.EndpointParameters) (resolver smithyEndpoints.Endpoint, err error) {
var uri *url.URL
resolver, err = iam.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
if err != nil {
return
}
if e.endpoint == "" {
// TODO: not sure if we should set an error or ignore the missing endpoint value and continue with the default resolver
return
}
uri, err = url.Parse(e.endpoint)
if err != nil {
return
}
if uri == nil {
// TODO: not sure if we should set an error or ignore the missing uri value and continue with the default resolver
return
}
resolver = smithyEndpoints.Endpoint{
URI: *uri,
}
return
}

// stsEndpointResolver is a implementation of the sts.EndpointResolverV2 interface.
// The stsEndpointResolver is used with the STSClient when the CredentialConfig
// is configured with a custom endpoint for sts.
type stsEndpointResolver struct {
endpoint string
}

// ResolveEndpoint returns a smithyEndpoint resolver for the sts endpoint given
// to the CredentialConfig. The resolver is set to the default endpoint resovler
// provided by aws when the custom sts endpoint is empty or is not a parsable url.
func (e *stsEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (resolver smithyEndpoints.Endpoint, err error) {
var uri *url.URL
resolver, err = sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
if err != nil {
return
}
if e.endpoint == "" {
// TODO: not sure if we should set an error or ignore the missing endpoint value and continue with the default resolver
return
}
uri, err = url.Parse(e.endpoint)
if err != nil {
return
}
if uri == nil {
// TODO: not sure if we should set an error or ignore the missing uri value and continue with the default resolver
return
}
resolver = smithyEndpoints.Endpoint{
URI: *uri,
}
return
}
121 changes: 121 additions & 0 deletions awsutil/v2/clients_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsutil

import (
"context"
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/stretchr/testify/require"
)

const testOptionErr = "test option error"

func TestCredentialsConfigIAMClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual IAMClient)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "with mock IAM session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithIAMAPIFunc(NewMockIAM())},
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockIAM{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.IsType(&iam.Client{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.IAMClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}

func TestCredentialsConfigSTSClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual STSClient)
requireErr string
}{
{
name: "options error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "with mock STS session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithSTSAPIFunc(NewMockSTS())},
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockSTS{}, actual)
},
},
{
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.IsType(&sts.Client{}, actual)
},
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.STSClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
}

require.NoError(err)
tc.require(t, actual)
})
}
}
37 changes: 37 additions & 0 deletions awsutil/v2/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsutil

import (
"errors"

"github.com/aws/aws-sdk-go-v2/aws/retry"
multierror "github.com/hashicorp/go-multierror"
)

var ErrUpstreamRateLimited = errors.New("upstream rate limited")

// CheckAWSError will examine an error and convert to a logical error if
// appropriate. If no appropriate error is found, return nil
func CheckAWSError(err error) error {
retryErr := retry.ThrottleErrorCode{
Codes: retry.DefaultThrottleErrorCodes,
}
if retryErr.IsErrorThrottle(err).Bool() {
return ErrUpstreamRateLimited
}
return nil
}

// AppendAWSError checks if the given error is a known AWS error we modify,
// and if so then returns a go-multierror, appending the original and the
// AWS error.
// If the error is not an AWS error, or not an error we wish to modify, then
// return the original error.
func AppendAWSError(err error) error {
if awserr := CheckAWSError(err); awserr != nil {
err = multierror.Append(err, awserr)
}
return err
}
Loading

0 comments on commit 1d97ab4

Please sign in to comment.