Skip to content

Commit

Permalink
perf: use smart pointers instead of owned strings (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli authored Apr 6, 2023
1 parent a0b2164 commit 802452b
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 56 deletions.
24 changes: 16 additions & 8 deletions src/cmds/assume_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
sts::{extract_sts_err, StsAction},
};
use async_trait::async_trait;
use std::borrow::Cow;

#[derive(clap::Parser, Debug, Default)]
pub struct AssumeRole {
Expand Down Expand Up @@ -39,8 +40,11 @@ impl ProfileName for AssumeRole {
}

#[async_trait]
impl StsAction for AssumeRole {
type Output = ShortTermProfile;
impl<'a> StsAction for &'a AssumeRole {
type Output = ShortTermProfile<'a>;

const DEFAULT_DURATION: i32 = 3600;

async fn execute(
&self,
config: &Config,
Expand All @@ -53,19 +57,23 @@ impl StsAction for AssumeRole {
.assume_role()
.set_role_arn(Some(self.role_arn.clone()))
.set_role_session_name(Some(self.role_name.clone()))
.set_serial_number(Some(lt_profile.mfa_device.clone()))
.set_serial_number(Some(lt_profile.mfa_device.to_string()))
.set_token_code(Some(mfa_token))
.set_duration_seconds(config.duration.or(Some(3600)))
.set_duration_seconds(config.duration.or(Some(Self::DEFAULT_DURATION)))
.send()
.await
.map_err(extract_sts_err)?;
let mut stp = ShortTermProfile::try_from(output.credentials())?;
let assumed_role = output.assumed_role_user().unwrap();

let mut stp = ShortTermProfile::try_from(output.credentials)?;

// Assumed_role_arn is the user input role_arn, not the actual
// role_arn returned by STS
stp.assumed_role_arn = Some(self.role_arn.clone());
stp.assumed_role_arn = Some(Cow::Borrowed(&self.role_arn));
// Assumed_role_id is the actual role_id returned by STS
stp.assumed_role_id = Some(assumed_role.assumed_role_id().unwrap().to_string());
stp.assumed_role_id = output
.assumed_role_user
.map(|v| v.assumed_role_id)
.unwrap_or_default();

Ok(stp)
}
Expand Down
18 changes: 10 additions & 8 deletions src/cmds/session_token.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use async_trait::async_trait;

use crate::{
config::Config,
profile::{LongTermProfile, ProfileName, ShortTermProfile},
sts::{extract_sts_err, StsAction},
};
use async_trait::async_trait;

#[derive(clap::Parser, Debug, Default)]
pub struct SessionToken;
Expand All @@ -16,8 +15,11 @@ impl ProfileName for SessionToken {
}

#[async_trait]
impl StsAction for SessionToken {
type Output = ShortTermProfile;
impl<'a> StsAction for &'a SessionToken {
type Output = ShortTermProfile<'a>;

const DEFAULT_DURATION: i32 = 43200;

async fn execute(
&self,
config: &Config,
Expand All @@ -28,13 +30,13 @@ impl StsAction for SessionToken {
.create_client()
.await
.get_session_token()
.serial_number(lt_profile.mfa_device.clone())
.duration_seconds(config.duration.unwrap_or(43200))
.token_code(mfa_token.to_string())
.serial_number(lt_profile.mfa_device.to_string())
.duration_seconds(config.duration.unwrap_or(Self::DEFAULT_DURATION))
.token_code(mfa_token)
.send()
.await
.map_err(extract_sts_err)?;
let short_term_profile = ShortTermProfile::try_from(output.credentials())?;
let short_term_profile = ShortTermProfile::try_from(output.credentials)?;
Ok(short_term_profile)
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ impl Config {

fn validate_credentials_path(&mut self) -> anyhow::Result<()> {
if self.credentials.is_relative() {
self.credentials = dirs::home_dir().unwrap().join(self.credentials.as_path());
self.credentials = dirs::home_dir()
.expect("Cannot find home directory")
.join(self.credentials.as_path());
}
if !self.credentials.is_file() {
anyhow::bail!("The credentials file does not exist");
Expand Down
23 changes: 12 additions & 11 deletions src/creds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use crate::{
utils::get_remaining_time,
};
use ini::{Ini, Properties};
use std::path::Path;
use std::{borrow::Cow, path::Path};
use std::{fmt::Debug, time::SystemTime};
use thiserror::Error;

pub struct CredentialsHandler(pub Ini);

impl CredentialsHandler {
Expand All @@ -21,10 +22,10 @@ impl CredentialsHandler {
Ok(Self(Ini::load_from_file(path)?))
}

pub fn get_long_term_profile(
&self,
conf: &Config,
) -> Result<LongTermProfile, CredentialsError> {
pub fn get_long_term_profile<'a>(
&'a self,
conf: &'a Config,
) -> Result<LongTermProfile<'a>, CredentialsError> {
let profile = &conf.profile_name;
let sections = self
.0
Expand All @@ -37,23 +38,23 @@ impl CredentialsHandler {
1 => {
let section = sections[0];
let mut pf = LongTermProfile {
name: profile.to_owned(),
name: Cow::Borrowed(profile),
..Default::default()
};
match section.get(LongTermProfile::ACCESS_KEY) {
Some(access_key) => pf.access_key = access_key.to_owned(),
Some(access_key) => pf.access_key = Cow::Borrowed(access_key),
None => Err(CredentialsError::NoAccessKey(profile.to_owned()))?,
}
match section.get(LongTermProfile::SECRET_KEY) {
Some(secret_key) => pf.secret_key = secret_key.to_owned(),
Some(secret_key) => pf.secret_key = Cow::Borrowed(secret_key),
None => Err(CredentialsError::NoSecretKey(profile.to_owned()))?,
}
match conf
.mfa_device
.as_deref()
.or(section.get(LongTermProfile::MFA_DEVICE))
{
Some(mfa_device) => pf.mfa_device = mfa_device.to_owned(),
Some(mfa_device) => pf.mfa_device = Cow::Borrowed(mfa_device),
None => Err(CredentialsError::NoMfaDevice(profile.to_owned()))?,
}

Expand All @@ -80,7 +81,7 @@ impl CredentialsHandler {
self.0.set_to(
Some(profile_name),
LongTermProfile::ASSUMED_ROLE_ARN.to_owned(),
arn.to_owned(),
arn.to_string(),
);
}

Expand Down Expand Up @@ -242,7 +243,7 @@ mod test_short_term_profile {
let mut handler = CredentialsHandler::_new("").unwrap();
let profile = ShortTermProfile {
assumed_role_id: Some("id".to_owned()),
assumed_role_arn: Some("arn".to_owned()),
assumed_role_arn: Some(Cow::Owned("arn".to_owned())),
..Default::default()
};
let profile_name = "test";
Expand Down
49 changes: 22 additions & 27 deletions src/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ use crate::config::Config;
use aws_sdk_sts::primitives::DateTime as AwsDateTime;
use aws_sdk_sts::types::Credentials as StsCredentials;
use aws_smithy_types::date_time::{ConversionError, DateTimeParseError, Format};
use std::{ops::Deref, str::FromStr, time::SystemTime};
use std::{borrow::Cow, ops::Deref, str::FromStr, time::SystemTime};

#[derive(Debug, Default)]
pub struct LongTermProfile {
pub name: String,
pub access_key: String,
pub secret_key: String,
pub mfa_device: String,
pub struct LongTermProfile<'a> {
pub name: Cow<'a, str>,
pub access_key: Cow<'a, str>,
pub secret_key: Cow<'a, str>,
pub mfa_device: Cow<'a, str>,
}
#[derive(Debug, Default)]
pub struct ShortTermProfile {
pub struct ShortTermProfile<'a> {
pub access_key: String,
pub secret_key: String,
pub session_token: String,
pub expiration: DateTime,
pub assumed_role_id: Option<String>,
pub assumed_role_arn: Option<String>,
pub assumed_role_arn: Option<Cow<'a, str>>,
}

#[derive(Debug, Clone)]
Expand All @@ -34,8 +34,8 @@ pub trait Profile {
const EXPIRATION: &'static str = "expiration";
}

impl Profile for LongTermProfile {}
impl Profile for ShortTermProfile {}
impl<'a> Profile for LongTermProfile<'a> {}
impl<'a> Profile for ShortTermProfile<'a> {}

pub trait ProfileName {
fn short_profile_name(&self, config: &Config) -> String;
Expand Down Expand Up @@ -68,33 +68,28 @@ impl TryFrom<DateTime> for SystemTime {
}
}

impl ShortTermProfile {
impl<'a> ShortTermProfile<'a> {
pub fn format_expiration(&self) -> String {
self.expiration.fmt(Format::DateTime).unwrap()
}
}

impl TryFrom<Option<&StsCredentials>> for ShortTermProfile {
impl<'a> TryFrom<Option<StsCredentials>> for ShortTermProfile<'a> {
type Error = anyhow::Error;
fn try_from(creds: Option<&StsCredentials>) -> anyhow::Result<Self> {
fn try_from(creds: Option<StsCredentials>) -> anyhow::Result<Self> {
let creds = creds.ok_or_else(|| anyhow::anyhow!("Failed to extract STS credentials"))?;

if let (
Some(access_key_id),
Some(secret_access_key),
Some(session_token),
Some(expiration),
) = (
creds.access_key_id(),
creds.secret_access_key(),
creds.session_token(),
creds.expiration(),
if let (Some(access_key), Some(secret_key), Some(session_token), Some(expiration)) = (
creds.access_key_id,
creds.secret_access_key,
creds.session_token,
creds.expiration,
) {
Ok(Self {
access_key: access_key_id.to_owned(),
secret_key: secret_access_key.to_owned(),
session_token: session_token.to_owned(),
expiration: DateTime(expiration.to_owned()),
access_key,
secret_key,
session_token,
expiration: DateTime(expiration),
assumed_role_arn: None,
assumed_role_id: None,
})
Expand Down
5 changes: 4 additions & 1 deletion src/sts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use aws_sdk_sts::{
#[async_trait]
pub trait StsAction {
type Output;

const DEFAULT_DURATION: i32;

async fn execute(
&self,
config: &Config,
Expand All @@ -26,7 +29,7 @@ pub trait StsAction {
}
}

impl LongTermProfile {
impl<'a> LongTermProfile<'a> {
pub async fn create_client(&self) -> Client {
let credentials =
Credentials::from_keys(self.access_key.clone(), self.secret_key.clone(), None);
Expand Down

0 comments on commit 802452b

Please sign in to comment.