Unverified Commit 40f46621 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Replace credentials cache with identity cache (#3077)

This PR replaces the credentials cache with the new identity cache, and
adds config validation via the `SharedConfigValidator` runtime component
and `ValidateConfig` trait.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 66a3acf5
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -450,3 +450,21 @@ message = """
references = ["smithy-rs#3076"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "**This change has [detailed upgrade guidance](https://github.com/awslabs/aws-sdk-rust/discussions/923).** <br><br>The AWS credentials cache has been replaced with a more generic identity cache."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "**Behavior Break!** Identities for auth are now cached by default. See the `Config` builder's `identity_cache()` method docs for an example of how to disable this caching."
references = ["smithy-rs#3077"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Clients now have a default async sleep implementation so that one does not need to be specified if you're using Tokio."
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"
+4 −1
Original line number Diff line number Diff line
@@ -12,15 +12,18 @@ allowed_external_types = [
   "aws_smithy_async::rt::sleep::SharedAsyncSleep",
   "aws_smithy_async::time::SharedTimeSource",
   "aws_smithy_async::time::TimeSource",
   "aws_smithy_types::body::SdkBody",
   "aws_smithy_http::endpoint",
   "aws_smithy_http::endpoint::error::InvalidEndpointError",
   "aws_smithy_http::result::SdkError",
   "aws_smithy_runtime::client::identity::cache::IdentityCache",
   "aws_smithy_runtime::client::identity::cache::lazy::LazyCacheBuilder",
   "aws_smithy_runtime_api::client::dns::ResolveDns",
   "aws_smithy_runtime_api::client::dns::SharedDnsResolver",
   "aws_smithy_runtime_api::client::http::HttpClient",
   "aws_smithy_runtime_api::client::http::SharedHttpClient",
   "aws_smithy_runtime_api::client::identity::ResolveCachedIdentity",
   "aws_smithy_runtime_api::client::identity::ResolveIdentity",
   "aws_smithy_types::body::SdkBody",
   "aws_smithy_types::retry",
   "aws_smithy_types::retry::*",
   "aws_smithy_types::timeout",
+4 −1
Original line number Diff line number Diff line
@@ -433,7 +433,6 @@ impl Builder {
            .runtime_plugin(common_plugin.clone())
            .runtime_plugin(TokenRuntimePlugin::new(
                common_plugin,
                config.time_source(),
                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
            ))
            .with_connection_poisoning()
@@ -748,6 +747,7 @@ pub(crate) mod test {
    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
    #[tokio::test]
    async fn token_refresh_buffer() {
        let _logs = capture_test_logs();
        let (_, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
@@ -785,11 +785,14 @@ pub(crate) mod test {
            .token_ttl(Duration::from_secs(600))
            .build();

        tracing::info!("resp1 -----------------------------------------------------------");
        let resp1 = client.get("/latest/metadata").await.expect("success");
        // now the cached credential has expired
        time_source.advance(Duration::from_secs(400));
        tracing::info!("resp2 -----------------------------------------------------------");
        let resp2 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(150));
        tracing::info!("resp3 -----------------------------------------------------------");
        let resp3 = client.get("/latest/metadata").await.expect("success");
        http_client.assert_requests_match(&[]);
        assert_eq!("test-imds-output1", resp1.as_ref());
+46 −38
Original line number Diff line number Diff line
@@ -14,10 +14,11 @@
//! - Retry token loading when it fails
//! - Attach the token to the request in the `x-aws-ec2-metadata-token` header

use crate::identity::IdentityCache;
use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
use aws_credential_types::cache::ExpiringCache;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::expiring_cache::ExpiringCache;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
@@ -50,6 +51,12 @@ const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl
const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);

#[derive(Debug)]
struct TtlToken {
    value: HeaderValue,
    ttl: Duration,
}

/// IMDS Token
#[derive(Clone)]
struct Token {
@@ -76,20 +83,18 @@ pub(super) struct TokenRuntimePlugin {
}

impl TokenRuntimePlugin {
    pub(super) fn new(
        common_plugin: SharedRuntimePlugin,
        time_source: SharedTimeSource,
        token_ttl: Duration,
    ) -> Self {
    pub(super) fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
        Self {
            components: RuntimeComponentsBuilder::new("TokenRuntimePlugin")
                .with_auth_scheme(TokenAuthScheme::new())
                .with_auth_scheme_option_resolver(Some(StaticAuthSchemeOptionResolver::new(vec![
                    IMDS_TOKEN_AUTH_SCHEME,
                ])))
                // The TokenResolver has a cache of its own, so don't use identity caching
                .with_identity_cache(Some(IdentityCache::no_cache()))
                .with_identity_resolver(
                    IMDS_TOKEN_AUTH_SCHEME,
                    TokenResolver::new(common_plugin, time_source, token_ttl),
                    TokenResolver::new(common_plugin, token_ttl),
                ),
        }
    }
@@ -107,8 +112,7 @@ impl RuntimePlugin for TokenRuntimePlugin {
#[derive(Debug)]
struct TokenResolverInner {
    cache: ExpiringCache<Token, ImdsError>,
    refresh: Operation<(), Token, TokenError>,
    time_source: SharedTimeSource,
    refresh: Operation<(), TtlToken, TokenError>,
}

#[derive(Clone, Debug)]
@@ -117,11 +121,7 @@ struct TokenResolver {
}

impl TokenResolver {
    fn new(
        common_plugin: SharedRuntimePlugin,
        time_source: SharedTimeSource,
        token_ttl: Duration,
    ) -> Self {
    fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
        Self {
            inner: Arc::new(TokenResolverInner {
                cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER),
@@ -141,26 +141,26 @@ impl TokenResolver {
                            .try_into()
                            .unwrap())
                    })
                    .deserializer({
                        let time_source = time_source.clone();
                        move |response| {
                            let now = time_source.now();
                            parse_token_response(response, now)
                                .map_err(OrchestratorError::operation)
                        }
                    .deserializer(move |response| {
                        parse_token_response(response).map_err(OrchestratorError::operation)
                    })
                    .build(),
                time_source,
            }),
        }
    }

    async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> {
        self.inner
            .refresh
            .invoke(())
            .await
    async fn get_token(
        &self,
        time_source: SharedTimeSource,
    ) -> Result<(Token, SystemTime), ImdsError> {
        let result = self.inner.refresh.invoke(()).await;
        let now = time_source.now();
        result
            .map(|token| {
                let token = Token {
                    value: token.value,
                    expiry: now + token.ttl,
                };
                let expiry = token.expiry;
                (token, expiry)
            })
@@ -168,7 +168,7 @@ impl TokenResolver {
    }
}

fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Token, TokenError> {
fn parse_token_response(response: &HttpResponse) -> Result<TtlToken, TokenError> {
    match response.status().as_u16() {
        400 => return Err(TokenErrorKind::InvalidParameters.into()),
        403 => return Err(TokenErrorKind::Forbidden.into()),
@@ -187,30 +187,38 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
        .map_err(|_| TokenErrorKind::InvalidTtl)?
        .parse()
        .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;
    Ok(Token {
    Ok(TtlToken {
        value,
        expiry: now + Duration::from_secs(ttl),
        ttl: Duration::from_secs(ttl),
    })
}

impl ResolveIdentity for TokenResolver {
    fn resolve_identity<'a>(
        &'a self,
        _components: &'a RuntimeComponents,
        components: &'a RuntimeComponents,
        _config_bag: &'a ConfigBag,
    ) -> IdentityFuture<'a> {
        let time_source = components
            .time_source()
            .expect("time source required for IMDS token caching");
        IdentityFuture::new(async {
            let preloaded_token = self
                .inner
                .cache
                .yield_or_clear_if_expired(self.inner.time_source.now())
                .await;
            let now = time_source.now();
            let preloaded_token = self.inner.cache.yield_or_clear_if_expired(now).await;
            let token = match preloaded_token {
                Some(token) => Ok(token),
                Some(token) => {
                    tracing::trace!(
                        buffer_time=?TOKEN_REFRESH_BUFFER,
                        expiration=?token.expiry,
                        now=?now,
                        "loaded IMDS token from cache");
                    Ok(token)
                }
                None => {
                    tracing::debug!("IMDS token cache miss");
                    self.inner
                        .cache
                        .get_or_load(|| async { self.get_token().await })
                        .get_or_load(|| async { self.get_token(time_source).await })
                        .await
                }
            }?;
+39 −53
Original line number Diff line number Diff line
@@ -102,6 +102,12 @@ pub use aws_types::{
/// Load default sources for all configuration with override support
pub use loader::ConfigLoader;

/// Types for configuring identity caching.
pub mod identity {
    pub use aws_smithy_runtime::client::identity::IdentityCache;
    pub use aws_smithy_runtime::client::identity::LazyCacheBuilder;
}

#[allow(dead_code)]
const PKG_VERSION: &str = env!("CARGO_PKG_VERSION");

@@ -155,11 +161,11 @@ mod loader {
    use crate::meta::region::ProvideRegion;
    use crate::profile::profile_file::ProfileFiles;
    use crate::provider_config::ProviderConfig;
    use aws_credential_types::cache::CredentialsCache;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep, SharedAsyncSleep};
    use aws_smithy_async::time::{SharedTimeSource, TimeSource};
    use aws_smithy_runtime_api::client::http::HttpClient;
    use aws_smithy_runtime_api::client::identity::{ResolveCachedIdentity, SharedIdentityCache};
    use aws_smithy_runtime_api::shared::IntoShared;
    use aws_smithy_types::retry::RetryConfig;
    use aws_smithy_types::timeout::TimeoutConfig;
@@ -189,7 +195,7 @@ mod loader {
    #[derive(Default, Debug)]
    pub struct ConfigLoader {
        app_name: Option<AppName>,
        credentials_cache: Option<CredentialsCache>,
        identity_cache: Option<SharedIdentityCache>,
        credentials_provider: CredentialsProviderOption,
        endpoint_url: Option<String>,
        region: Option<Box<dyn ProvideRegion>>,
@@ -333,22 +339,45 @@ mod loader {
            self
        }

        /// Override the credentials cache used to build [`SdkConfig`](aws_types::SdkConfig).
        /// The credentials cache has been replaced. Use the identity_cache() method instead. See its rustdoc for an example.
        #[deprecated(
            note = "The credentials cache has been replaced. Use the identity_cache() method instead for equivalent functionality. See its rustdoc for an example."
        )]
        pub fn credentials_cache(self) -> Self {
            self
        }

        /// Override the identity cache used to build [`SdkConfig`](aws_types::SdkConfig).
        ///
        /// The identity cache caches AWS credentials and SSO tokens. By default, a lazy cache is used
        /// that will load credentials upon first request, cache them, and then reload them during
        /// another request when they are close to expiring.
        ///
        /// # Examples
        ///
        /// Override the credentials cache but load the default value for region:
        /// Change a setting on the default lazy caching implementation:
        /// ```no_run
        /// # use aws_credential_types::cache::CredentialsCache;
        /// use aws_config::identity::IdentityCache;
        /// use std::time::Duration;
        ///
        /// # async fn create_config() {
        /// let config = aws_config::from_env()
        ///     .credentials_cache(CredentialsCache::lazy())
        ///     .identity_cache(
        ///         IdentityCache::lazy()
        ///             // Change the load timeout to 10 seconds.
        ///             // Note: there are other timeouts that could trigger if the load timeout is too long.
        ///             .load_timeout(Duration::from_secs(10))
        ///             .build()
        ///     )
        ///     .load()
        ///     .await;
        /// # }
        /// ```
        pub fn credentials_cache(mut self, credentials_cache: CredentialsCache) -> Self {
            self.credentials_cache = Some(credentials_cache);
        pub fn identity_cache(
            mut self,
            identity_cache: impl ResolveCachedIdentity + 'static,
        ) -> Self {
            self.identity_cache = Some(identity_cache.into_shared());
            self
        }

@@ -656,17 +685,6 @@ mod loader {
                CredentialsProviderOption::ExplicitlyUnset => None,
            };

            let credentials_cache = if credentials_provider.is_some() {
                Some(self.credentials_cache.unwrap_or_else(|| {
                    let mut builder =
                        CredentialsCache::lazy_builder().time_source(conf.time_source());
                    builder.set_sleep_impl(conf.sleep_impl());
                    builder.into_credentials_cache()
                }))
            } else {
                None
            };

            let mut builder = SdkConfig::builder()
                .region(region)
                .retry_config(retry_config)
@@ -675,7 +693,7 @@ mod loader {

            builder.set_http_client(self.http_client);
            builder.set_app_name(app_name);
            builder.set_credentials_cache(credentials_cache);
            builder.set_identity_cache(self.identity_cache);
            builder.set_credentials_provider(credentials_provider);
            builder.set_sleep_impl(sleep_impl);
            builder.set_endpoint_url(self.endpoint_url);
@@ -705,13 +723,11 @@ mod loader {
        use crate::{from_env, ConfigLoader};
        use aws_credential_types::provider::ProvideCredentials;
        use aws_smithy_async::rt::sleep::TokioSleep;
        use aws_smithy_async::time::{StaticTimeSource, TimeSource};
        use aws_smithy_runtime::client::http::test_util::{infallible_client_fn, NeverClient};
        use aws_types::app_name::AppName;
        use aws_types::os_shim_internal::{Env, Fs};
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::time::{SystemTime, UNIX_EPOCH};
        use tracing_test::traced_test;

        #[tokio::test]
@@ -800,7 +816,7 @@ mod loader {
        #[tokio::test]
        async fn disable_default_credentials() {
            let config = from_env().no_credentials().load().await;
            assert!(config.credentials_cache().is_none());
            assert!(config.identity_cache().is_none());
            assert!(config.credentials_provider().is_none());
        }

@@ -827,35 +843,5 @@ mod loader {
            let num_requests = num_requests.load(Ordering::Relaxed);
            assert!(num_requests > 0, "{}", num_requests);
        }

        #[tokio::test]
        async fn time_source_is_passed() {
            #[derive(Debug)]
            struct PanicTs;
            impl TimeSource for PanicTs {
                fn now(&self) -> SystemTime {
                    panic!("timesource-was-used")
                }
            }
            let config = from_env()
                .sleep_impl(InstantSleep)
                .time_source(StaticTimeSource::new(UNIX_EPOCH))
                .http_client(no_traffic_client())
                .load()
                .await;
            // assert that the innards contain the customized fields
            for inner in ["InstantSleep", "StaticTimeSource"] {
                assert!(
                    format!("{:#?}", config.credentials_cache()).contains(inner),
                    "{:#?}",
                    config.credentials_cache()
                );
                assert!(
                    format!("{:#?}", config.credentials_provider()).contains(inner),
                    "{:#?}",
                    config.credentials_cache()
                );
            }
        }
    }
}
Loading