Unverified Commit b50c2baf authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement support for `sso-session` in AWS config file (#3379)

This PR implements parsing support for `[sso-session name]` in the
`~/.aws/config` file.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 5e20575e
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -16,3 +16,9 @@ message = "The types in the aws-http crate were moved into aws-runtime. Deprecat
references = ["smithy-rs#3355"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Add support for `[sso-session]` in AWS config file for AWS Identity Center SSO credentials. Note that this does not include support for AWS Builder ID SSO sessions for services such as Code Catalyst (these lack the `sso_account_id` and `sso_role_name` fields in the profile config). Support for AWS Builder IDs is still being tracked in https://github.com/awslabs/aws-sdk-rust/issues/703."
references = ["aws-sdk-rust#703", "smithy-rs#3379"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"
+46 −0
Original line number Diff line number Diff line
@@ -265,6 +265,29 @@ pub enum ProfileFileError {
        /// Additional information about the missing feature
        message: Option<Cow<'static, str>>,
    },

    /// Missing sso-session section in config
    #[non_exhaustive]
    MissingSsoSession {
        /// The name of the profile that specified `sso_session`
        profile: String,
        /// SSO session name
        sso_session: String,
    },

    /// Invalid SSO configuration
    #[non_exhaustive]
    InvalidSsoConfig {
        /// The name of the profile that the error originates in
        profile: String,
        /// Error message
        message: Cow<'static, str>,
    },

    /// Profile is intended to be used in the token provider chain rather
    /// than in the credentials chain.
    #[non_exhaustive]
    TokenProviderConfig {},
}

impl ProfileFileError {
@@ -324,6 +347,25 @@ impl Display for ProfileFileError {
                    "This behavior requires following cargo feature(s) enabled: {feature}. {message}",
                )
            }
            ProfileFileError::MissingSsoSession {
                profile,
                sso_session,
            } => {
                write!(f, "sso-session named `{sso_session}` (referenced by profile `{profile}`) was not found")
            }
            ProfileFileError::InvalidSsoConfig { profile, message } => {
                write!(f, "profile `{profile}` has invalid SSO config: {message}")
            }
            ProfileFileError::TokenProviderConfig { .. } => {
                // TODO(https://github.com/awslabs/aws-sdk-rust/issues/703): Update error message once token support is added
                write!(
                    f,
                    "selected profile will resolve an access token instead of credentials \
                     since it doesn't have `sso_account_id` and `sso_role_name` set. Access token \
                     support for services such as Code Catalyst hasn't been implemented yet and is \
                     being tracked in https://github.com/awslabs/aws-sdk-rust/issues/703"
                )
            }
        }
    }
}
@@ -497,4 +539,8 @@ mod test {
    make_test!(credential_process_failure);
    #[cfg(feature = "credentials-process")]
    make_test!(credential_process_invalid);
    #[cfg(feature = "sso")]
    make_test!(sso_credentials);
    #[cfg(feature = "sso")]
    make_test!(sso_token);
}
+7 −2
Original line number Diff line number Diff line
@@ -130,19 +130,24 @@ impl ProviderChain {
                sso_region,
                sso_role_name,
                sso_start_url,
                sso_session_name,
            } => {
                #[cfg(feature = "sso")]
                {
                    use crate::sso::{credentials::SsoProviderConfig, SsoCredentialsProvider};
                    use aws_types::region::Region;

                    let (Some(sso_account_id), Some(sso_role_name)) =
                        (sso_account_id, sso_role_name)
                    else {
                        return Err(ProfileFileError::TokenProviderConfig {});
                    };
                    let sso_config = SsoProviderConfig {
                        account_id: sso_account_id.to_string(),
                        role_name: sso_role_name.to_string(),
                        start_url: sso_start_url.to_string(),
                        region: Region::new(sso_region.to_string()),
                        // TODO(https://github.com/awslabs/aws-sdk-rust/issues/703): Implement sso_session_name profile property
                        session_name: None,
                        session_name: sso_session_name.map(|s| s.to_string()),
                    };
                    Arc::new(SsoCredentialsProvider::new(provider_config, sso_config))
                }
+108 −37
Original line number Diff line number Diff line
@@ -78,10 +78,13 @@ pub(super) enum BaseProvider<'a> {

    /// An SSO Provider
    Sso {
        sso_account_id: &'a str,
        sso_session_name: Option<&'a str>,
        sso_region: &'a str,
        sso_role_name: &'a str,
        sso_start_url: &'a str,

        // Credentials from SSO fields
        sso_account_id: Option<&'a str>,
        sso_role_name: Option<&'a str>,
    },

    /// A profile that specifies a `credential_process`
@@ -172,7 +175,7 @@ pub(super) fn resolve_chain(
                chain.push(role_provider);
                next
            } else {
                break base_provider(profile).map_err(|err| {
                break base_provider(profile_set, profile).map_err(|err| {
                    // It's possible for base_provider to return a `ProfileFileError::ProfileDidNotContainCredentials`
                    // if we're still looking at the first provider we want to surface it. However,
                    // if we're looking at any provider after the first we want to instead return a `ProfileFileError::InvalidCredentialSource`
@@ -193,7 +196,7 @@ pub(super) fn resolve_chain(
                // self referential profile, don't go through the loop because it will error
                // on the infinite loop check. Instead, reload this profile as a base profile
                // and exit.
                break base_provider(profile)?;
                break base_provider(profile_set, profile)?;
            }
            NextProfile::Named(name) => source_profile_name = name,
        }
@@ -216,6 +219,7 @@ mod sso {
    pub(super) const REGION: &str = "sso_region";
    pub(super) const ROLE_NAME: &str = "sso_role_name";
    pub(super) const START_URL: &str = "sso_start_url";
    pub(super) const SESSION_NAME: &str = "sso_session";
}

mod web_identity_token {
@@ -234,12 +238,15 @@ mod credential_process {

const PROVIDER_NAME: &str = "ProfileFile";

fn base_provider(profile: &Profile) -> Result<BaseProvider<'_>, ProfileFileError> {
fn base_provider<'a>(
    profile_set: &'a ProfileSet,
    profile: &'a Profile,
) -> Result<BaseProvider<'a>, ProfileFileError> {
    // the profile must define either a `CredentialsSource` or a concrete set of access keys
    match profile.get(role::CREDENTIAL_SOURCE) {
        Some(source) => Ok(BaseProvider::NamedSource(source)),
        None => web_identity_token_from_profile(profile)
            .or_else(|| sso_from_profile(profile))
            .or_else(|| sso_from_profile(profile_set, profile).transpose())
            .or_else(|| credential_process_from_profile(profile))
            .unwrap_or_else(|| Ok(BaseProvider::AccessKey(static_creds_from_profile(profile)?))),
    }
@@ -292,39 +299,93 @@ fn role_arn_from_profile(profile: &Profile) -> Option<RoleArn<'_>> {
    })
}

fn sso_from_profile(profile: &Profile) -> Option<Result<BaseProvider<'_>, ProfileFileError>> {
fn sso_from_profile<'a>(
    profile_set: &'a ProfileSet,
    profile: &'a Profile,
) -> Result<Option<BaseProvider<'a>>, ProfileFileError> {
    /*
    Sample:
    -- Sample without sso-session: --

    [profile sample-profile]
    sso_account_id = 012345678901
    sso_region = us-east-1
    sso_role_name = SampleRole
    sso_start_url = https://d-abc123.awsapps.com/start-beta

    -- Sample with sso-session: --

    [profile sample-profile]
    sso_session = dev
    sso_account_id = 012345678901
    sso_role_name = SampleRole

    [sso-session dev]
    sso_region = us-east-1
    sso_start_url = https://d-abc123.awsapps.com/start-beta
    */
    let account_id = profile.get(sso::ACCOUNT_ID);
    let region = profile.get(sso::REGION);
    let role_name = profile.get(sso::ROLE_NAME);
    let start_url = profile.get(sso::START_URL);
    if [account_id, region, role_name, start_url]
    let sso_account_id = profile.get(sso::ACCOUNT_ID);
    let mut sso_region = profile.get(sso::REGION);
    let sso_role_name = profile.get(sso::ROLE_NAME);
    let mut sso_start_url = profile.get(sso::START_URL);
    let sso_session_name = profile.get(sso::SESSION_NAME);
    if [
        sso_account_id,
        sso_region,
        sso_role_name,
        sso_start_url,
        sso_session_name,
    ]
    .iter()
        .all(|field| field.is_none())
    .all(Option::is_none)
    {
        return None;
        return Ok(None);
    }

    let invalid_sso_config = |s: &str| ProfileFileError::InvalidSsoConfig {
        profile: profile.name().into(),
        message: format!(
            "`{s}` can only be specified in the [sso-session] config when a session name is given"
        )
        .into(),
    };
    if let Some(sso_session_name) = sso_session_name {
        if sso_start_url.is_some() {
            return Err(invalid_sso_config(sso::START_URL));
        }
        if sso_region.is_some() {
            return Err(invalid_sso_config(sso::REGION));
        }
        if let Some(session) = profile_set.sso_session(sso_session_name) {
            sso_start_url = session.get(sso::START_URL);
            sso_region = session.get(sso::REGION);
        } else {
            return Err(ProfileFileError::MissingSsoSession {
                profile: profile.name().into(),
                sso_session: sso_session_name.into(),
            });
        }
    }

    let invalid_sso_creds = |left: &str, right: &str| ProfileFileError::InvalidSsoConfig {
        profile: profile.name().into(),
        message: format!("if `{left}` is set, then `{right}` must also be set").into(),
    };
    match (sso_account_id, sso_role_name) {
        (Some(_), Some(_)) | (None, None) => { /* good */ }
        (Some(_), None) => return Err(invalid_sso_creds(sso::ACCOUNT_ID, sso::ROLE_NAME)),
        (None, Some(_)) => return Err(invalid_sso_creds(sso::ROLE_NAME, sso::ACCOUNT_ID)),
    }

    let missing_field = |s| move || ProfileFileError::missing_field(profile, s);
    let parse_profile = || {
        let sso_account_id = account_id.ok_or_else(missing_field(sso::ACCOUNT_ID))?;
        let sso_region = region.ok_or_else(missing_field(sso::REGION))?;
        let sso_role_name = role_name.ok_or_else(missing_field(sso::ROLE_NAME))?;
        let sso_start_url = start_url.ok_or_else(missing_field(sso::START_URL))?;
        Ok(BaseProvider::Sso {
    let sso_region = sso_region.ok_or_else(missing_field(sso::REGION))?;
    let sso_start_url = sso_start_url.ok_or_else(missing_field(sso::START_URL))?;
    Ok(Some(BaseProvider::Sso {
        sso_account_id,
        sso_region,
        sso_role_name,
        sso_start_url,
        })
    };
    Some(parse_profile())
        sso_session_name,
    }))
}

fn web_identity_token_from_profile(
@@ -429,7 +490,11 @@ mod tests {
    }

    fn check(test_case: TestCase) {
        let source = ProfileSet::new(test_case.input.profile, test_case.input.selected_profile);
        let source = ProfileSet::new(
            test_case.input.profiles,
            test_case.input.selected_profile,
            test_case.input.sso_sessions,
        );
        let actual = resolve_chain(&source);
        let expected = test_case.output;
        match (expected, actual) {
@@ -458,8 +523,10 @@ mod tests {

    #[derive(Deserialize)]
    struct TestInput {
        profile: HashMap<String, HashMap<String, String>>,
        profiles: HashMap<String, HashMap<String, String>>,
        selected_profile: String,
        #[serde(default)]
        sso_sessions: HashMap<String, HashMap<String, String>>,
    }

    fn to_test_output(profile_chain: ProfileChain<'_>) -> Vec<Provider> {
@@ -484,15 +551,17 @@ mod tests {
                role_session_name: session_name.map(|sess| sess.to_string()),
            }),
            BaseProvider::Sso {
                sso_account_id,
                sso_region,
                sso_role_name,
                sso_start_url,
                sso_session_name,
                sso_account_id,
                sso_role_name,
            } => output.push(Provider::Sso {
                sso_account_id: sso_account_id.into(),
                sso_region: sso_region.into(),
                sso_role_name: sso_role_name.into(),
                sso_start_url: sso_start_url.into(),
                sso_session: sso_session_name.map(|s| s.to_string()),
                sso_account_id: sso_account_id.map(|s| s.to_string()),
                sso_role_name: sso_role_name.map(|s| s.to_string()),
            }),
        };
        for role in profile_chain.chain {
@@ -531,10 +600,12 @@ mod tests {
            role_session_name: Option<String>,
        },
        Sso {
            sso_account_id: String,
            sso_region: String,
            sso_role_name: String,
            sso_start_url: String,
            sso_session: Option<String>,

            sso_account_id: Option<String>,
            sso_role_name: Option<String>,
        },
    }

+200 −31
Original line number Diff line number Diff line
@@ -74,6 +74,7 @@ pub async fn load(
pub struct ProfileSet {
    profiles: HashMap<String, Profile>,
    selected_profile: Cow<'static, str>,
    sso_sessions: HashMap<String, SsoSession>,
}

impl ProfileSet {
@@ -84,6 +85,7 @@ impl ProfileSet {
    pub(crate) fn new(
        profiles: HashMap<String, HashMap<String, String>>,
        selected_profile: impl Into<Cow<'static, str>>,
        sso_sessions: HashMap<String, HashMap<String, String>>,
    ) -> Self {
        let mut base = ProfileSet::empty();
        base.selected_profile = selected_profile.into();
@@ -99,6 +101,18 @@ impl ProfileSet {
                ),
            );
        }
        for (name, session) in sso_sessions {
            base.sso_sessions.insert(
                name.clone(),
                SsoSession::new(
                    name,
                    session
                        .into_iter()
                        .map(|(k, v)| (k.clone(), Property::new(k, v)))
                        .collect(),
                ),
            );
        }
        base
    }

@@ -124,11 +138,21 @@ impl ProfileSet {
        self.profiles.is_empty()
    }

    /// Returns the names of the profiles in this profile set
    /// Returns the names of the profiles in this config
    pub fn profiles(&self) -> impl Iterator<Item = &str> {
        self.profiles.keys().map(String::as_ref)
    }

    /// Returns the names of the SSO sessions in this config
    pub fn sso_sessions(&self) -> impl Iterator<Item = &str> {
        self.sso_sessions.keys().map(String::as_ref)
    }

    /// Retrieves a named SSO session from the config
    pub(crate) fn sso_session(&self, name: &str) -> Option<&SsoSession> {
        self.sso_sessions.get(name)
    }

    fn parse(source: Source) -> Result<Self, ProfileParseError> {
        let mut base = ProfileSet::empty();
        base.selected_profile = source.profile;
@@ -143,35 +167,146 @@ impl ProfileSet {
        Self {
            profiles: Default::default(),
            selected_profile: "default".into(),
            sso_sessions: Default::default(),
        }
    }
}

/// An individual configuration profile
///
/// An AWS config may be composed of a multiple named profiles within a [`ProfileSet`].
/// Represents a top-level section (e.g., `[profile name]`) in a config file.
pub(crate) trait Section {
    /// The name of this section
    fn name(&self) -> &str;

    /// Returns all the properties in this section
    fn properties(&self) -> &HashMap<String, Property>;

    /// Returns a reference to the property named `name`
    fn get(&self, name: &str) -> Option<&str>;

    /// True if there are no properties in this section.
    fn is_empty(&self) -> bool;

    /// Insert a property into a section
    fn insert(&mut self, name: String, value: Property);
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Profile {
struct SectionInner {
    name: String,
    properties: HashMap<String, Property>,
}

impl Section for SectionInner {
    fn name(&self) -> &str {
        &self.name
    }

    fn properties(&self) -> &HashMap<String, Property> {
        &self.properties
    }

    fn get(&self, name: &str) -> Option<&str> {
        self.properties
            .get(to_ascii_lowercase(name).as_ref())
            .map(|prop| prop.value())
    }

    fn is_empty(&self) -> bool {
        self.properties.is_empty()
    }

    fn insert(&mut self, name: String, value: Property) {
        self.properties
            .insert(to_ascii_lowercase(&name).into(), value);
    }
}

/// An individual configuration profile
///
/// An AWS config may be composed of a multiple named profiles within a [`ProfileSet`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Profile(SectionInner);

impl Profile {
    /// Create a new profile
    pub fn new(name: String, properties: HashMap<String, Property>) -> Self {
        Self { name, properties }
    pub fn new(name: impl Into<String>, properties: HashMap<String, Property>) -> Self {
        Self(SectionInner {
            name: name.into(),
            properties,
        })
    }

    /// The name of this profile
    pub fn name(&self) -> &str {
        &self.name
        self.0.name()
    }

    /// Returns a reference to the property named `name`
    pub fn get(&self, name: &str) -> Option<&str> {
        self.properties
            .get(to_ascii_lowercase(name).as_ref())
            .map(|prop| prop.value())
        self.0.get(name)
    }
}

impl Section for Profile {
    fn name(&self) -> &str {
        self.0.name()
    }

    fn properties(&self) -> &HashMap<String, Property> {
        self.0.properties()
    }

    fn get(&self, name: &str) -> Option<&str> {
        self.0.get(name)
    }

    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    fn insert(&mut self, name: String, value: Property) {
        self.0.insert(name, value)
    }
}

/// A `[sso-session name]` section in the config.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct SsoSession(SectionInner);

impl SsoSession {
    /// Create a new SSO session section.
    pub(crate) fn new(name: impl Into<String>, properties: HashMap<String, Property>) -> Self {
        Self(SectionInner {
            name: name.into(),
            properties,
        })
    }

    /// Returns a reference to the property named `name`
    pub(crate) fn get(&self, name: &str) -> Option<&str> {
        self.0.get(name)
    }
}

impl Section for SsoSession {
    fn name(&self) -> &str {
        self.0.name()
    }

    fn properties(&self) -> &HashMap<String, Property> {
        self.0.properties()
    }

    fn get(&self, name: &str) -> Option<&str> {
        self.0.get(name)
    }

    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    fn insert(&mut self, name: String, value: Property) {
        self.0.insert(name, value)
    }
}

@@ -248,7 +383,10 @@ pub struct CouldNotReadProfileFile {

#[cfg(test)]
mod test {
    use crate::profile::parser::source::{File, Source};
    use crate::profile::parser::{
        source::{File, Source},
        Section,
    };
    use crate::profile::profile_file::ProfileFileKind;
    use crate::profile::ProfileSet;
    use arbitrary::{Arbitrary, Unstructured};
@@ -333,17 +471,28 @@ mod test {
    }

    // for test comparison purposes, flatten a profile into a hashmap
    fn flatten(profile: ProfileSet) -> HashMap<String, HashMap<String, String>> {
        profile
            .profiles
            .into_values()
            .map(|profile| {
    #[derive(Debug)]
    struct FlattenedProfileSet {
        profiles: HashMap<String, HashMap<String, String>>,
        sso_sessions: HashMap<String, HashMap<String, String>>,
    }
    fn flatten(config: ProfileSet) -> FlattenedProfileSet {
        FlattenedProfileSet {
            profiles: flatten_sections(config.profiles.values().map(|p| p as _)),
            sso_sessions: flatten_sections(config.sso_sessions.values().map(|s| s as _)),
        }
    }
    fn flatten_sections<'a>(
        sections: impl Iterator<Item = &'a dyn Section>,
    ) -> HashMap<String, HashMap<String, String>> {
        sections
            .map(|section| {
                (
                    profile.name,
                    profile
                        .properties
                        .into_values()
                        .map(|prop| (prop.key, prop.value))
                    section.name().to_string(),
                    section
                        .properties()
                        .values()
                        .map(|prop| (prop.key.clone(), prop.value.clone()))
                        .collect(),
                )
            })
@@ -373,11 +522,28 @@ mod test {
        let copy = test_case.clone();
        let parsed = ProfileSet::parse(make_source(test_case.input));
        let res = match (parsed.map(flatten), &test_case.output) {
            (Ok(actual), ParserOutput::Profiles(expected)) if &actual != expected => Err(format!(
                "mismatch:\nExpected: {:#?}\nActual: {:#?}",
                expected, actual
            )),
            (Ok(_), ParserOutput::Profiles(_)) => Ok(()),
            (
                Ok(FlattenedProfileSet {
                    profiles: actual_profiles,
                    sso_sessions: actual_sso_sessions,
                }),
                ParserOutput::Config {
                    profiles,
                    sso_sessions,
                },
            ) => {
                if profiles != &actual_profiles {
                    Err(format!(
                        "mismatched profiles:\nExpected: {profiles:#?}\nActual: {actual_profiles:#?}",
                    ))
                } else if sso_sessions != &actual_sso_sessions {
                    Err(format!(
                        "mismatched sso_sessions:\nExpected: {sso_sessions:#?}\nActual: {actual_sso_sessions:#?}",
                    ))
                } else {
                    Ok(())
                }
            }
            (Err(msg), ParserOutput::ErrorContaining(substr)) => {
                if format!("{}", msg).contains(substr) {
                    Ok(())
@@ -386,10 +552,9 @@ mod test {
                }
            }
            (Ok(output), ParserOutput::ErrorContaining(err)) => Err(format!(
                "expected an error: {} but parse succeeded:\n{:#?}",
                err, output
                "expected an error: {err} but parse succeeded:\n{output:#?}",
            )),
            (Err(err), ParserOutput::Profiles(_expected)) => {
            (Err(err), ParserOutput::Config { .. }) => {
                Err(format!("Expected to succeed but got: {}", err))
            }
        };
@@ -417,7 +582,11 @@ mod test {
    #[derive(Deserialize, Debug, Clone)]
    #[serde(rename_all = "camelCase")]
    enum ParserOutput {
        Profiles(HashMap<String, HashMap<String, String>>),
        Config {
            profiles: HashMap<String, HashMap<String, String>>,
            #[serde(default)]
            sso_sessions: HashMap<String, HashMap<String, String>>,
        },
        ErrorContaining(String),
    }

Loading