Skip to content

Commit

Permalink
[SNOW-1671768] Ignore max retries error when getting acelerate config (
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-astachowski authored Sep 27, 2024
1 parent 3428b00 commit a26ac8a
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 39 deletions.
77 changes: 38 additions & 39 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"math"
Expand All @@ -24,7 +23,6 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go"
"github.com/gabriel-vasile/mimetype"
)

Expand Down Expand Up @@ -591,46 +589,47 @@ func (sfa *snowflakeFileTransferAgent) updateFileMetadataWithPresignedURL() erro
return nil
}

type s3BucketAccelerateConfigGetter interface {
GetBucketAccelerateConfiguration(ctx context.Context, params *s3.GetBucketAccelerateConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketAccelerateConfigurationOutput, error)
}

type s3ClientCreator interface {
extractBucketNameAndPath(location string) (*s3Location, error)
createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
}

func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s3ClientCreator) error {
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
if err != nil {
return err
}
s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
if err != nil {
return err
}
client, ok := s3Cli.(s3BucketAccelerateConfigGetter)
if !ok {
return (&SnowflakeError{
Number: ErrFailedToConvertToS3Client,
SQLState: sfa.data.SQLState,
QueryID: sfa.data.QueryID,
Message: errMsgFailedToConvertToS3Client,
}).exceptionTelemetry(sfa.sc)
}
ret, err := client.GetBucketAccelerateConfiguration(context.Background(), &s3.GetBucketAccelerateConfigurationInput{
Bucket: &s3Loc.bucketName,
})
sfa.useAccelerateEndpoint = ret != nil && ret.Status == "Enabled"
if err != nil {
logger.WithContext(sfa.sc.ctx).Warnln("An error occurred when getting accelerate config:", err)
}
return nil
}

func (sfa *snowflakeFileTransferAgent) transferAccelerateConfig() error {
if sfa.stageLocationType == s3Client {
s3Util := new(snowflakeS3Client)
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
if err != nil {
return err
}
s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
if err != nil {
return err
}
client, ok := s3Cli.(*s3.Client)
if !ok {
return (&SnowflakeError{
Number: ErrFailedToConvertToS3Client,
SQLState: sfa.data.SQLState,
QueryID: sfa.data.QueryID,
Message: errMsgFailedToConvertToS3Client,
}).exceptionTelemetry(sfa.sc)
}
ret, err := client.GetBucketAccelerateConfiguration(context.Background(), &s3.GetBucketAccelerateConfigurationInput{
Bucket: &s3Loc.bucketName,
})
sfa.useAccelerateEndpoint = ret != nil && ret.Status == "Enabled"
if err != nil {
var ae smithy.APIError
if errors.As(err, &ae) {
if ae.ErrorCode() == "AccessDenied" {
return nil
} else if ae.ErrorCode() == "MethodNotAllowed" {
return nil
} else if strings.EqualFold(ae.ErrorCode(), "UnsupportedArgument") {
// In AWS China and US Gov partitions, Transfer Acceleration is not supported
// https://docs.amazonaws.cn/en_us/aws/latest/userguide/s3.html#feature-diff
// https://docs.aws.amazon.com/govcloud-us/latest/UserGuide/govcloud-s3.html
return nil
}
}
return err
}
return sfa.transferAccelerateConfigWithUtil(s3Util)
}
return nil
}
Expand Down
130 changes: 130 additions & 0 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/service/s3"
"io"
"net/url"
"os"
Expand Down Expand Up @@ -49,6 +50,135 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {
})
}

type s3ClientCreatorMock struct {
extract func(string) (*s3Location, error)
create func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error)
}

func (mock *s3ClientCreatorMock) extractBucketNameAndPath(location string) (*s3Location, error) {
return mock.extract(location)
}

func (mock *s3ClientCreatorMock) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return mock.create(info, useAccelerateEndpoint)
}

type s3BucketAccelerateConfigGetterMock struct {
err error
}

func (mock *s3BucketAccelerateConfigGetterMock) GetBucketAccelerateConfiguration(ctx context.Context, params *s3.GetBucketAccelerateConfigurationInput, optFns ...func(*s3.Options)) (*s3.GetBucketAccelerateConfigurationOutput, error) {
return nil, mock.err
}

func TestGetBucketAccelerateConfigurationTooManyRetries(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
buf := &bytes.Buffer{}
logger.SetOutput(buf)
err := logger.SetLogLevel("warn")
if err != nil {
return
}
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
data: &execResponseData{
SrcLocations: make([]string, 0),
},
stageInfo: &execResponseStageInfo{
Location: "test",
},
}
err = sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return &s3BucketAccelerateConfigGetterMock{err: errors.New("testing")}, nil
},
})
assertNilE(t, err)
assertStringContainsE(t, buf.String(), "msg=\"An error occurred when getting accelerate config: testing\"")
})
}

func TestGetBucketAccelerateConfigurationFailedExtractBucketNameAndPath(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
data: &execResponseData{
SrcLocations: make([]string, 0),
},
stageInfo: &execResponseStageInfo{
Location: "test",
},
}
err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
extract: func(s string) (*s3Location, error) {
return nil, errors.New("failed extraction")
},
})
assertNotNilE(t, err)
})
}

func TestGetBucketAccelerateConfigurationFailedCreateClient(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
data: &execResponseData{
SrcLocations: make([]string, 0),
},
stageInfo: &execResponseStageInfo{
Location: "test",
},
}
err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return nil, errors.New("failed creation")
},
})
assertNotNilE(t, err)
})
}

func TestGetBucketAccelerateConfigurationInvalidClient(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
data: &execResponseData{
SrcLocations: make([]string, 0),
},
stageInfo: &execResponseStageInfo{
Location: "test",
},
}
err := sfa.transferAccelerateConfigWithUtil(&s3ClientCreatorMock{
extract: func(s string) (*s3Location, error) {
return &s3Location{bucketName: "test", s3Path: "test"}, nil
},
create: func(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
return 1, nil
},
})
assertNotNilE(t, err)
})
}

func TestUnitDownloadWithInvalidLocalPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "data")
if err != nil {
Expand Down

0 comments on commit a26ac8a

Please sign in to comment.