Unverified Commit f573cc22 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Only create inner chain provider once in profile creds provider (#3387)

The `SsoCredentialsProvider` maintains an in-memory expiring cache of
SSO tokens, and this cache is maintained within its instance. The
`ProfileFileCredentialsProvider`, which uses `SsoCredentialsProvider` as
a base/inner provider, is currently reconstructing the inner provider
every time it loads credentials, which discards the SSO token cache
entirely.

This PR refactors the `ProfileFileCredentialsProvider` to cache the
inner provider so that it is only initialized once on first load of
credentials. This is done via a new `ErrorTakingOnceCell` abstraction
since the same init-caching mechanism will be needed for the
`ProfileFileTokenProvider` when it is implemented.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 166f0e29
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -98,3 +98,9 @@ message = "Cap the maximum jitter fraction for credentials cache refresh buffer
references = ["smithy-rs#3402"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "Retain the SSO token cache between calls to `provide_credentials` when using IAM Identity Center SSO via the AWS config file."
references = ["smithy-rs#3387"]
meta = { "breaking" = false, "bug" = true, "tada" = false }
author = "jdisanti"
+258 −62
Original line number Diff line number Diff line
@@ -22,11 +22,14 @@
//! - `exec` which contains a chain representation of providers to implement passing bootstrapped credentials
//! through a series of providers.

use crate::profile::parser::ProfileFileLoadError;
use crate::profile::profile_file::ProfileFiles;
use crate::profile::Profile;
use crate::profile::{cell::ErrorTakingOnceCell, parser::ProfileFileLoadError};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::{
    provider::{self, error::CredentialsError, future, ProvideCredentials},
    Credentials,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::SdkConfig;
use std::borrow::Cow;
@@ -39,15 +42,6 @@ use tracing::Instrument;
mod exec;
mod repr;

impl ProvideCredentials for ProfileFileCredentialsProvider {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
    where
        Self: 'a,
    {
        future::ProvideCredentials::new(self.load_credentials())
    }
}

/// AWS Profile based credentials provider
///
/// This credentials provider will load credentials from `~/.aws/config` and `~/.aws/credentials`.
@@ -139,6 +133,12 @@ impl ProvideCredentials for ProfileFileCredentialsProvider {
#[doc = include_str!("location_of_profile_files.md")]
#[derive(Debug)]
pub struct ProfileFileCredentialsProvider {
    config: Arc<Config>,
    inner_provider: ErrorTakingOnceCell<ChainProvider, CredentialsError>,
}

#[derive(Debug)]
struct Config {
    factory: exec::named::NamedProviderFactory,
    sdk_config: SdkConfig,
    provider_config: ProviderConfig,
@@ -151,50 +151,51 @@ impl ProfileFileCredentialsProvider {
    }

    async fn load_credentials(&self) -> provider::Result {
        let inner_provider = build_provider_chain(&self.provider_config, &self.factory)
            .await
            .map_err(|err| match err {
        // The inner provider needs to be cached across successive calls to load_credentials
        // since the base providers can potentially have information cached in their instances.
        // For example, the SsoCredentialsProvider maintains an in-memory expiring token cache.
        let inner_provider = self
            .inner_provider
            .get_or_init(
                {
                    let config = self.config.clone();
                    move || async move {
                        match build_provider_chain(config.clone()).await {
                            Ok(chain) => Ok(ChainProvider {
                                config: config.clone(),
                                chain: Some(Arc::new(chain)),
                            }),
                            Err(err) => match err {
                                ProfileFileError::NoProfilesDefined
                                | ProfileFileError::ProfileDidNotContainCredentials { .. } => {
                    CredentialsError::not_loaded(err)
                                    Ok(ChainProvider {
                                        config: config.clone(),
                                        chain: None,
                                    })
                                }
                _ => CredentialsError::invalid_configuration(format!(
                                _ => Err(CredentialsError::invalid_configuration(format!(
                                    "ProfileFile provider could not be built: {}",
                                    &err
                )),
            })?;
        let mut creds = match inner_provider
            .base()
            .provide_credentials()
            .instrument(tracing::debug_span!("load_base_credentials"))
            .await
        {
            Ok(creds) => {
                tracing::info!(creds = ?creds, "loaded base credentials");
                creds
            }
            Err(e) => {
                tracing::warn!(error = %DisplayErrorContext(&e), "failed to load base credentials");
                return Err(CredentialsError::provider_error(e));
            }
        };
        for provider in inner_provider.chain().iter() {
            let next_creds = provider
                .credentials(creds, &self.sdk_config)
                .instrument(tracing::debug_span!("load_assume_role", provider = ?provider))
                .await;
            match next_creds {
                Ok(next_creds) => {
                    tracing::info!(creds = ?next_creds, "loaded assume role credentials");
                    creds = next_creds
                                ))),
                            },
                        }
                Err(e) => {
                    tracing::warn!(provider = ?provider, "failed to load assume role credentials");
                    return Err(CredentialsError::provider_error(e));
                    }
                },
                CredentialsError::unhandled(
                    "profile file credentials provider initialization error already taken",
                ),
            )
            .await?;
        inner_provider.provide_credentials().await
    }
}
        Ok(creds)

impl ProvideCredentials for ProfileFileCredentialsProvider {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
    where
        Self: 'a,
    {
        future::ProvideCredentials::new(self.load_credentials())
    }
}

@@ -486,24 +487,78 @@ impl Builder {
        let factory = exec::named::NamedProviderFactory::new(named_providers);

        ProfileFileCredentialsProvider {
            config: Arc::new(Config {
                factory,
                sdk_config: conf.client_config(),
                provider_config: conf,
            }),
            inner_provider: ErrorTakingOnceCell::new(),
        }
    }
}

async fn build_provider_chain(
    provider_config: &ProviderConfig,
    factory: &exec::named::NamedProviderFactory,
    config: Arc<Config>,
) -> Result<exec::ProviderChain, ProfileFileError> {
    let profile_set = provider_config
    let profile_set = config
        .provider_config
        .try_profile()
        .await
        .map_err(|parse_err| ProfileFileError::InvalidProfile(parse_err.clone()))?;
    let repr = repr::resolve_chain(profile_set)?;
    tracing::info!(chain = ?repr, "constructed abstract provider from config file");
    exec::ProviderChain::from_repr(provider_config, repr, factory)
    exec::ProviderChain::from_repr(&config.provider_config, repr, &config.factory)
}

#[derive(Debug)]
struct ChainProvider {
    config: Arc<Config>,
    chain: Option<Arc<exec::ProviderChain>>,
}

impl ChainProvider {
    async fn provide_credentials(&self) -> Result<Credentials, CredentialsError> {
        // Can't borrow `self` across an await point, or else we lose `Send` on the returned future
        let config = self.config.clone();
        let chain = self.chain.clone();

        if let Some(chain) = chain {
            let mut creds = match chain
                .base()
                .provide_credentials()
                .instrument(tracing::debug_span!("load_base_credentials"))
                .await
            {
                Ok(creds) => {
                    tracing::info!(creds = ?creds, "loaded base credentials");
                    creds
                }
                Err(e) => {
                    tracing::warn!(error = %DisplayErrorContext(&e), "failed to load base credentials");
                    return Err(CredentialsError::provider_error(e));
                }
            };
            for provider in chain.chain().iter() {
                let next_creds = provider
                    .credentials(creds, &config.sdk_config)
                    .instrument(tracing::debug_span!("load_assume_role", provider = ?provider))
                    .await;
                match next_creds {
                    Ok(next_creds) => {
                        tracing::info!(creds = ?next_creds, "loaded assume role credentials");
                        creds = next_creds
                    }
                    Err(e) => {
                        tracing::warn!(provider = ?provider, "failed to load assume role credentials");
                        return Err(CredentialsError::provider_error(e));
                    }
                }
            }
            Ok(creds)
        } else {
            Err(CredentialsError::not_loaded_no_source())
        }
    }
}

#[cfg(test)]
@@ -544,3 +599,144 @@ mod test {
    #[cfg(feature = "sso")]
    make_test!(sso_token);
}

#[cfg(all(test, feature = "sso"))]
mod sso_tests {
    use crate::{profile::credentials::Builder, provider_config::ProviderConfig};
    use aws_credential_types::provider::ProvideCredentials;
    use aws_sdk_sso::config::RuntimeComponents;
    use aws_smithy_runtime_api::client::{
        http::{
            HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings,
            SharedHttpConnector,
        },
        orchestrator::{HttpRequest, HttpResponse},
    };
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::{Env, Fs};
    use std::collections::HashMap;

    // In order to preserve the SSO token cache, the inner provider must only
    // be created once, rather than once per credential resolution.
    #[tokio::test]
    async fn create_inner_provider_exactly_once() {
        #[derive(Debug)]
        struct ClientInner {
            expected_token: &'static str,
        }
        impl HttpConnector for ClientInner {
            fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
                assert_eq!(
                    self.expected_token,
                    request.headers().get("x-amz-sso_bearer_token").unwrap()
                );
                HttpConnectorFuture::ready(Ok(HttpResponse::new(
                    200.try_into().unwrap(),
                    SdkBody::from("{\"roleCredentials\":{\"accessKeyId\":\"ASIARTESTID\",\"secretAccessKey\":\"TESTSECRETKEY\",\"sessionToken\":\"TESTSESSIONTOKEN\",\"expiration\": 1651516560000}}"),
                )))
            }
        }
        #[derive(Debug)]
        struct Client {
            inner: SharedHttpConnector,
        }
        impl Client {
            fn new(expected_token: &'static str) -> Self {
                Self {
                    inner: SharedHttpConnector::new(ClientInner { expected_token }),
                }
            }
        }
        impl HttpClient for Client {
            fn http_connector(
                &self,
                _settings: &HttpConnectorSettings,
                _components: &RuntimeComponents,
            ) -> SharedHttpConnector {
                self.inner.clone()
            }
        }

        let fs = Fs::from_map({
            let mut map = HashMap::new();
            map.insert(
                "/home/.aws/config".to_string(),
                br#"
[profile default]
sso_session = dev
sso_account_id = 012345678901
sso_role_name = SampleRole
region = us-east-1

[sso-session dev]
sso_region = us-east-1
sso_start_url = https://d-abc123.awsapps.com/start
                "#
                .to_vec(),
            );
            map.insert(
                "/home/.aws/sso/cache/34c6fceca75e456f25e7e99531e2425c6c1de443.json".to_string(),
                br#"
                {
                    "accessToken": "secret-access-token",
                    "expiresAt": "2199-11-14T04:05:45Z",
                    "refreshToken": "secret-refresh-token",
                    "clientId": "ABCDEFG323242423121312312312312312",
                    "clientSecret": "ABCDE123",
                    "registrationExpiresAt": "2199-03-06T19:53:17Z",
                    "region": "us-east-1",
                    "startUrl": "https://d-abc123.awsapps.com/start"
                }
                "#
                .to_vec(),
            );
            map
        });
        let provider_config = ProviderConfig::empty()
            .with_fs(fs.clone())
            .with_env(Env::from_slice(&[("HOME", "/home")]))
            .with_http_client(Client::new("secret-access-token"));
        let provider = Builder::default().configure(&provider_config).build();

        let first_creds = provider.provide_credentials().await.unwrap();

        // Write to the token cache with an access token that won't match the fake client's
        // expected access token, and thus, won't return SSO credentials.
        fs.write(
            "/home/.aws/sso/cache/34c6fceca75e456f25e7e99531e2425c6c1de443.json",
            r#"
            {
                "accessToken": "NEW!!secret-access-token",
                "expiresAt": "2199-11-14T04:05:45Z",
                "refreshToken": "secret-refresh-token",
                "clientId": "ABCDEFG323242423121312312312312312",
                "clientSecret": "ABCDE123",
                "registrationExpiresAt": "2199-03-06T19:53:17Z",
                "region": "us-east-1",
                "startUrl": "https://d-abc123.awsapps.com/start"
            }
            "#,
        )
        .await
        .unwrap();

        // Loading credentials will still work since the SSOTokenProvider should have only
        // been created once, and thus, the correct token is still in an in-memory cache.
        let second_creds = provider
            .provide_credentials()
            .await
            .expect("used cached token instead of loading from the file system");
        assert_eq!(first_creds, second_creds);

        // Now create a new provider, which should use the new cached token value from the file system
        // since it won't have the in-memory cache. We do this just to verify that the FS mutation above
        // actually worked correctly.
        let provider_config = ProviderConfig::empty()
            .with_fs(fs.clone())
            .with_env(Env::from_slice(&[("HOME", "/home")]))
            .with_http_client(Client::new("NEW!!secret-access-token"));
        let provider = Builder::default().configure(&provider_config).build();
        let third_creds = provider.provide_credentials().await.unwrap();
        assert_eq!(second_creds, third_creds);
    }
}
+95 −0
Original line number Diff line number Diff line
@@ -26,3 +26,98 @@ pub mod region;
pub use credentials::ProfileFileCredentialsProvider;
#[doc(inline)]
pub use region::ProfileFileRegionProvider;

mod cell {
    use std::future::Future;
    use std::sync::{Arc, Mutex};
    use tokio::sync::OnceCell;

    /// Once cell with a result where the error can be taken.
    ///
    /// The profile providers need to cache their inner provider value (specifically for SSO)
    /// in order to preserve the SSO token cache. This wrapper around [`OnceCell`] allows
    /// for initializing the inner provider once in a way that if it fails, the error can
    /// be taken so that it doesn't need to implement `Clone`.
    #[derive(Debug)]
    pub(super) struct ErrorTakingOnceCell<T, E> {
        cell: OnceCell<Result<Arc<T>, Mutex<E>>>,
    }

    impl<T, E> ErrorTakingOnceCell<T, E> {
        pub(super) fn new() -> Self {
            Self {
                cell: OnceCell::new(),
            }
        }

        pub(super) async fn get_or_init<F, Fut>(
            &self,
            init: F,
            mut taken_error: E,
        ) -> Result<Arc<T>, E>
        where
            F: FnOnce() -> Fut,
            Fut: Future<Output = Result<T, E>>,
        {
            let init = || async move { (init)().await.map(Arc::new).map_err(Mutex::new) };
            match self.cell.get_or_init(init).await {
                Ok(value) => Ok(value.clone()),
                Err(err) => {
                    let mut locked = err.lock().unwrap();
                    std::mem::swap(&mut *locked, &mut taken_error);
                    Err(taken_error)
                }
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use crate::profile::cell::ErrorTakingOnceCell;
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        #[derive(Debug)]
        enum Error {
            InitError,
            Taken,
        }

        #[tokio::test]
        async fn taken_error() {
            let cell = ErrorTakingOnceCell::new();
            let calls = AtomicUsize::new(0);
            let init = || async {
                calls.fetch_add(1, Ordering::SeqCst);
                Result::<String, _>::Err(Error::InitError)
            };

            let result = cell.get_or_init(init, Error::Taken).await;
            assert!(matches!(result, Err(Error::InitError)));

            let result = cell.get_or_init(init, Error::Taken).await;
            assert!(matches!(result, Err(Error::Taken)));

            let result = cell.get_or_init(init, Error::Taken).await;
            assert!(matches!(result, Err(Error::Taken)));
            assert_eq!(1, calls.load(Ordering::SeqCst));
        }

        #[tokio::test]
        async fn value_initialized_once() {
            let cell = ErrorTakingOnceCell::new();
            let calls = AtomicUsize::new(0);
            let init = || async {
                calls.fetch_add(1, Ordering::SeqCst);
                Result::<_, Error>::Ok("test".to_string())
            };

            let original = cell.get_or_init(init, Error::Taken).await.unwrap();
            let next = cell.get_or_init(init, Error::Taken).await.unwrap();
            assert!(Arc::ptr_eq(&original, &next));
            assert_eq!(1, calls.load(Ordering::SeqCst));
        }
    }
}
+12 −3
Original line number Diff line number Diff line
@@ -86,7 +86,7 @@ pub mod error {
    /// Details for [`CredentialsError::CredentialsNotLoaded`]
    #[derive(Debug)]
    pub struct CredentialsNotLoaded {
        source: Box<dyn Error + Send + Sync + 'static>,
        source: Option<Box<dyn Error + Send + Sync + 'static>>,
    }

    /// Details for [`CredentialsError::ProviderTimedOut`]
@@ -160,10 +160,19 @@ pub mod error {
        /// that the provider was configured in some way, but certain settings were invalid.
        pub fn not_loaded(source: impl Into<Box<dyn Error + Send + Sync + 'static>>) -> Self {
            CredentialsError::CredentialsNotLoaded(CredentialsNotLoaded {
                source: source.into(),
                source: Some(source.into()),
            })
        }

        /// The credentials provider did not provide credentials
        ///
        /// This error indicates the credentials provider was not enable or no configuration was set.
        /// This contrasts with [`invalid_configuration`](CredentialsError::InvalidConfiguration), indicating
        /// that the provider was configured in some way, but certain settings were invalid.
        pub fn not_loaded_no_source() -> Self {
            CredentialsError::CredentialsNotLoaded(CredentialsNotLoaded { source: None })
        }

        /// An unexpected error occurred loading credentials from this provider
        ///
        /// Unhandled errors should not occur during normal operation and should be reserved for exceptional
@@ -227,7 +236,7 @@ pub mod error {
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            match self {
                CredentialsError::CredentialsNotLoaded(details) => {
                    Some(details.source.as_ref() as _)
                    details.source.as_ref().map(|s| s.as_ref() as _)
                }
                CredentialsError::ProviderTimedOut(_) => None,
                CredentialsError::InvalidConfiguration(details) => {