diff --git a/src/profile.rs b/src/profile.rs index 4bb8d7f..e1ff419 100644 --- a/src/profile.rs +++ b/src/profile.rs @@ -70,7 +70,7 @@ impl TryFrom for SystemTime { } impl<'a> LongTermProfile<'a> { - pub async fn create_client(&self) -> STSClient { + pub async fn create_client(&self, region: String) -> STSClient { let credentials = AWSCredentials::new( self.access_key.clone(), self.secret_key.clone(), @@ -81,7 +81,7 @@ impl<'a> LongTermProfile<'a> { let conf = StsConfig::Builder::new() .behavior_version(StsConfig::BehaviorVersion::v2023_11_09()) .credentials_provider(credentials) - .region(Some(StsConfig::Region::new("eu-central-5"))) + .region(Some(StsConfig::Region::new(region))) .build(); STSClient::from_conf(conf) diff --git a/src/sts/assume_role.rs b/src/sts/assume_role.rs index 41440b0..6f5289d 100644 --- a/src/sts/assume_role.rs +++ b/src/sts/assume_role.rs @@ -62,7 +62,7 @@ impl ShortTermCredentials for AssumeRole { lt_profile: &LongTermProfile<'_>, ) -> anyhow::Result { let output = lt_profile - .create_client() + .create_client(config.sts_region.clone()) .await .assume_role() .set_role_arn(Some(self.role_arn.clone())) diff --git a/src/sts/config.rs b/src/sts/config.rs index cc6ea1d..d59bfae 100644 --- a/src/sts/config.rs +++ b/src/sts/config.rs @@ -29,6 +29,13 @@ pub struct CommonStsConfig { help = "Force the creation of a new short-term profile even if one already exists" )] pub force_new_credentials: bool, + #[arg( + long, + global = true, + default_value = "us-east-1", + help = "The STS region to use for the AWS client" + )] + pub sts_region: String, } impl CommonStsConfig { diff --git a/src/sts/session_token.rs b/src/sts/session_token.rs index de12d4d..6fcef84 100644 --- a/src/sts/session_token.rs +++ b/src/sts/session_token.rs @@ -35,7 +35,7 @@ impl ShortTermCredentials for SessionToken { lt_profile: &LongTermProfile<'_>, ) -> anyhow::Result { let output = lt_profile - .create_client() + .create_client(config.sts_region.clone()) .await .get_session_token() .serial_number(lt_profile.mfa_device.to_string())