diff --git a/s3/session.go b/s3/session.go index df9d551..9f03baf 100644 --- a/s3/session.go +++ b/s3/session.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "os" "strconv" "strings" @@ -23,7 +24,7 @@ const HTTP = "http" // TODO : unit tests // Given an S3 bucket name, attempt to determine its region -func findBucketRegion(bucket string, config *aws.Config) (string, error) { +var findBucketRegion = func(bucket string, config *aws.Config) (string, error) { input := s3.GetBucketLocationInput{ Bucket: aws.String(bucket), } @@ -47,13 +48,35 @@ func findBucketRegion(bucket string, config *aws.Config) (string, error) { } // TODO : unit tests -func getAWSRegion(s3Bucket string, config *aws.Config, settings map[string]string) (string, error) { +func GetAWSRegion(s3Bucket string, config *aws.Config, settings map[string]string) (string, error) { if region, ok := settings[RegionSetting]; ok { return region, nil } + if config.Endpoint == nil || - *config.Endpoint == "" || - strings.HasSuffix(*config.Endpoint, ".amazonaws.com") { + *config.Endpoint == "" { + + if config.Endpoint != nil { + hostAddr, parseErr := url.Parse(*config.Endpoint) + + if parseErr != nil { + return "us-east-1", parseErr + } + + host, _, err := net.SplitHostPort(hostAddr.Host) + + if err != nil { + return "us-east-1", err + } + + if strings.HasSuffix(host, ".amazonaws.com") { + region, err := findBucketRegion(s3Bucket, config) + return region, errors.Wrapf(err, "%s is not set and s3:GetBucketLocation failed", RegionSetting) + } + + return "us-east-1", nil + } + region, err := findBucketRegion(s3Bucket, config) return region, errors.Wrapf(err, "%s is not set and s3:GetBucketLocation failed", RegionSetting) } else { @@ -122,7 +145,7 @@ func createSession(bucket string, settings map[string]string) (*session.Session, config.S3ForcePathStyle = aws.Bool(s3ForcePathStyle) } - region, err := getAWSRegion(bucket, config, settings) + region, err := GetAWSRegion(bucket, config, settings) if err != nil { return nil, err } diff --git a/s3/session_test.go b/s3/session_test.go new file mode 100644 index 0000000..d1ead2d --- /dev/null +++ b/s3/session_test.go @@ -0,0 +1,49 @@ +package s3 + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + "github.com/aws/aws-sdk-go/aws/defaults" + "testing" +) + +var bucket = "s3://test-bucket/wal-g-test-folder/Sub0" +var settings = map[string]string{ + EndpointSetting: "http://s3.mdst.yandex.net/", +} +var config = defaults.Get().Config.WithRegion(settings[RegionSetting]) + +func TestGetAWSRegionWithEmptyEndpoint(t *testing.T) { + findBucketRegion = func(bucket string, config *aws.Config) (string, error) { + return "europe", nil + } + + region, err := GetAWSRegion(bucket, config, settings) + + assert.Nil(t, err) + assert.Equal(t, region, "europe") +} + +func TestGetAWSRegionWithPort(t *testing.T) { + bucket = "s3://test-bucket:8080/wal-g-test-folder/Sub0" + findBucketRegion = func(bucket string, config *aws.Config) (string, error) { + return "europe", nil + } + + region, err := GetAWSRegion(bucket, config, settings) + + assert.Nil(t, err) + assert.Equal(t, region, "europe") +} + +func TestGetAWSRegionWithEndpoint(t *testing.T) { + config.Endpoint = aws.String("s3://test-bucket:8080/wal-g-test-folder/Sub0") + findBucketRegion = func(bucket string, config *aws.Config) (string, error) { + return "europe", nil + } + + region, err := GetAWSRegion(bucket, config, settings) + + assert.Nil(t, err) + assert.Equal(t, region, "us-east-1") +}