Unverified Commit 70999dad authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement timeouts for LazyCachingCredentialsProvider (#595)

* Implement timeouts for LazyCachingCredentialsProvider

* Rename refresh to reload

* Update CHANGELOG

* Fix clippy

* CR feedback

* Add note about panic on `LazyCachedCredentialsProvider` builder

* Fix doc comment code reference
parent 03ae7cc6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
## vNext (Month Day Year)
**New This Week**
- :tada: Add LazyCachingCredentialsProvider to aws-auth for use with expiring credentials, such as STS AssumeRole. Update STS example to use this new provider (#578, #595)
- :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615)
- Update SDK gradle build logic to use gradle properties (#620)
- (When complete) Add profile file provider for region (#594, #xyz)
+7 −0
Original line number Diff line number Diff line
@@ -5,8 +5,14 @@ authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@a
license = "Apache-2.0"
edition = "2018"

[features]
rt-tokio = ["smithy-async/rt-tokio"]
default = ["rt-tokio"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
pin-project = "1"
smithy-async = { path = "../../../rust-runtime/smithy-async", default-features = false }
smithy-http = { path = "../../../rust-runtime/smithy-http" }
aws-types = { path = "../aws-types" }
tokio = { version = "1", features = ["sync"] }
@@ -20,3 +26,4 @@ http = "0.2.3"
test-env-log = { version = "0.2.7", features = ["trace"] }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
smithy-async = { path = "../../../rust-runtime/smithy-async", features = ["rt-tokio"] }
+17 −0
Original line number Diff line number Diff line
@@ -3,6 +3,16 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

//! AWS credential providers, generic caching provider implementations, and traits to implement custom providers.
//!
//! Credentials providers acquire AWS credentials from environment variables, files,
//! or calls to AWS services such as STS. Custom credential provider implementations can
//! be provided by implementing [`ProvideCredentials`] for synchronous use-cases, or
//! [`AsyncProvideCredentials`] for async use-cases. Generic credential caching implementations,
//! for example,
//! [`LazyCachingCredentialsProvider`](crate::provider::lazy_caching::LazyCachingCredentialsProvider),
//! are also provided as part of this module.

mod cache;
pub mod env;
pub mod lazy_caching;
@@ -16,11 +26,13 @@ use std::fmt::{Debug, Display, Formatter};
use std::future::{self, Future};
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

#[derive(Debug)]
#[non_exhaustive]
pub enum CredentialsError {
    CredentialsNotLoaded,
    ProviderTimedOut(Duration),
    Unhandled(Box<dyn Error + Send + Sync + 'static>),
}

@@ -28,6 +40,11 @@ impl Display for CredentialsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            CredentialsError::CredentialsNotLoaded => write!(f, "CredentialsNotLoaded"),
            CredentialsError::ProviderTimedOut(d) => write!(
                f,
                "Credentials provider timed out after {} seconds",
                d.as_secs()
            ),
            CredentialsError::Unhandled(err) => write!(f, "{}", err),
        }
    }
+2 −0
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

//! Credential provider implementation that pulls from environment variables

use crate::provider::{CredentialsError, ProvideCredentials};
use crate::Credentials;
use aws_types::os_shim_internal::Env;
+93 −47
Original line number Diff line number Diff line
@@ -3,53 +3,51 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

//! Lazy, caching, credentials provider implementation

use crate::provider::cache::Cache;
use crate::provider::time::TimeSource;
use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsResult};
use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsError, CredentialsResult};
use smithy_async::future::timeout::Timeout;
use smithy_async::rt::sleep::AsyncSleep;
use std::sync::Arc;
use std::time::Duration;
use tracing::{trace_span, Instrument};

const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(15 * 60);
const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);

// TODO: Implement async runtime-agnostic timeouts
// TODO: Add catch_unwind() to handle panics
// TODO: Update doc comment below once catch_unwind and timeouts are implemented
// TODO: Update warning not to use this in the STS example once it's prod ready

/// `LazyCachingCredentialsProvider` implements [`AsyncProvideCredentials`] by caching
/// credentials that it loads by calling a user-provided [`AsyncProvideCredentials`] implementation.
///
/// For example, you can provide an [`AsyncProvideCredentials`] implementation that calls
/// AWS STS's AssumeRole operation to get temporary credentials, and `LazyCachingCredentialsProvider`
/// will cache those credentials until they expire.
///
/// # Note
///
/// This is __NOT__ production ready yet. Timeouts and panic safety have not been implemented yet.
pub struct LazyCachingCredentialsProvider {
    time: Box<dyn TimeSource>,
    sleeper: Box<dyn AsyncSleep>,
    cache: Cache,
    refresh: Arc<dyn AsyncProvideCredentials>,
    _refresh_timeout: Duration,
    loader: Arc<dyn AsyncProvideCredentials>,
    load_timeout: Duration,
    default_credential_expiration: Duration,
}

impl LazyCachingCredentialsProvider {
    fn new(
        time: impl TimeSource,
        refresh: Arc<dyn AsyncProvideCredentials>,
        refresh_timeout: Duration,
        sleeper: Box<dyn AsyncSleep>,
        loader: Arc<dyn AsyncProvideCredentials>,
        load_timeout: Duration,
        default_credential_expiration: Duration,
        buffer_time: Duration,
    ) -> Self {
        LazyCachingCredentialsProvider {
            time: Box::new(time),
            sleeper,
            cache: Cache::new(buffer_time),
            refresh,
            _refresh_timeout: refresh_timeout,
            loader,
            load_timeout,
            default_credential_expiration,
        }
    }
@@ -66,7 +64,9 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
        Self: 'a,
    {
        let now = self.time.now();
        let refresh = self.refresh.clone();
        let loader = self.loader.clone();
        let timeout_future = self.sleeper.sleep(self.load_timeout);
        let load_timeout = self.load_timeout;
        let cache = self.cache.clone();
        let default_credential_expiration = self.default_credential_expiration;

@@ -75,16 +75,18 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
            if let Some(credentials) = cache.yield_or_clear_if_expired(now).await {
                Ok(credentials)
            } else {
                // If we didn't get credentials from the cache, then we need to try and refresh.
                // There may be other threads also refreshing simultaneously, but this is OK
                // If we didn't get credentials from the cache, then we need to try and load.
                // There may be other threads also loading simultaneously, but this is OK
                // since the futures are not eagerly executed, and the cache will only run one
                // of them.
                let span = trace_span!("lazy_refresh_credentials");
                let future = refresh.provide_credentials();
                let span = trace_span!("lazy_load_credentials");
                let future = Timeout::new(loader.provide_credentials(), timeout_future);
                cache
                    .get_or_load(|| {
                        async move {
                            let mut credentials = future.await?;
                            let mut credentials = future
                                .await
                                .map_err(|_| CredentialsError::ProviderTimedOut(load_timeout))??;
                            // If the credentials don't have an expiration time, then create a default one
                            if credentials.expiry().is_none() {
                                *credentials.expiry_mut() =
@@ -92,7 +94,7 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
                            }
                            Ok(credentials)
                        }
                        // Only instrument the the actual refreshing future so that no span
                        // Only instrument the the actual load future so that no span
                        // is opened if the cache decides not to execute it.
                        .instrument(span)
                    })
@@ -105,10 +107,11 @@ impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
pub mod builder {
    use crate::provider::lazy_caching::{
        LazyCachingCredentialsProvider, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION,
        DEFAULT_REFRESH_TIMEOUT,
        DEFAULT_LOAD_TIMEOUT,
    };
    use crate::provider::time::SystemTimeSource;
    use crate::provider::AsyncProvideCredentials;
    use smithy_async::rt::sleep::{default_async_sleep, AsyncSleep};
    use std::sync::Arc;
    use std::time::Duration;

@@ -124,7 +127,7 @@ pub mod builder {
    /// use std::time::Duration;
    ///
    /// let provider = LazyCachingCredentialsProvider::builder()
    ///     .refresh(async_provide_credentials_fn(|| async {
    ///     .load(async_provide_credentials_fn(|| async {
    ///         // An async process to retrieve credentials would go here:
    ///         Ok(Credentials::from_keys("example", "example", None))
    ///     }))
@@ -132,8 +135,9 @@ pub mod builder {
    /// ```
    #[derive(Default)]
    pub struct Builder {
        refresh: Option<Arc<dyn AsyncProvideCredentials>>,
        refresh_timeout: Option<Duration>,
        sleep: Option<Box<dyn AsyncSleep>>,
        load: Option<Arc<dyn AsyncProvideCredentials>>,
        load_timeout: Option<Duration>,
        buffer_time: Option<Duration>,
        default_credential_expiration: Option<Duration>,
    }
@@ -143,18 +147,27 @@ pub mod builder {
            Default::default()
        }

        /// An implementation of [`AsyncProvideCredentials`] that will be used to refresh
        /// An implementation of [`AsyncProvideCredentials`] that will be used to load
        /// the cached credentials once they're expired.
        pub fn refresh(mut self, refresh: impl AsyncProvideCredentials + 'static) -> Self {
            self.refresh = Some(Arc::new(refresh));
        pub fn load(mut self, loader: impl AsyncProvideCredentials + 'static) -> Self {
            self.load = Some(Arc::new(loader));
            self
        }

        /// Implementation of [`AsyncSleep`] to use for timeouts. This enables use of
        /// the `LazyCachingCredentialsProvider` with other async runtimes.
        /// If using Tokio as the async runtime, this should be set to an instance of
        /// [`TokioSleep`](smithy_async::rt::sleep::TokioSleep).
        pub fn sleep(mut self, sleep: impl AsyncSleep + 'static) -> Self {
            self.sleep = Some(Box::new(sleep));
            self
        }

        /// (Optional) Timeout for the given [`AsyncProvideCredentials`] implementation.
        /// Defaults to 5 seconds.
        pub fn refresh_timeout(mut self, timeout: Duration) -> Self {
            self.refresh_timeout = Some(timeout);
            unimplemented!("refresh_timeout hasn't been implemented yet")
        pub fn load_timeout(mut self, timeout: Duration) -> Self {
            self.load_timeout = Some(timeout);
            self
        }

        /// (Optional) Amount of time before the actual credential expiration time
@@ -176,6 +189,11 @@ pub mod builder {
        }

        /// Creates the [`LazyCachingCredentialsProvider`].
        ///
        /// ## Note:
        /// This will panic if no `sleep` implementation is given and if no default crate features
        /// are used. By default, the [`TokioSleep`](smithy_async::rt::sleep::TokioSleep)
        /// implementation will be set automatically.
        pub fn build(self) -> LazyCachingCredentialsProvider {
            let default_credential_expiration = self
                .default_credential_expiration
@@ -186,8 +204,11 @@ pub mod builder {
            );
            LazyCachingCredentialsProvider::new(
                SystemTimeSource,
                self.refresh.expect("refresh provider is required"),
                self.refresh_timeout.unwrap_or(DEFAULT_REFRESH_TIMEOUT),
                self.sleep.unwrap_or_else(|| {
                    default_async_sleep().expect("no default sleep implementation available")
                }),
                self.load.expect("load implementation is required"),
                self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT),
                self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
                default_credential_expiration,
            )
@@ -199,12 +220,13 @@ pub mod builder {
mod tests {
    use crate::provider::lazy_caching::{
        LazyCachingCredentialsProvider, TimeSource, DEFAULT_BUFFER_TIME,
        DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_REFRESH_TIMEOUT,
        DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_LOAD_TIMEOUT,
    };
    use crate::provider::{
        async_provide_credentials_fn, AsyncProvideCredentials, CredentialsError, CredentialsResult,
    };
    use crate::Credentials;
    use smithy_async::rt::sleep::TokioSleep;
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, SystemTime};
    use tracing::info;
@@ -234,20 +256,21 @@ mod tests {

    fn test_provider<T: TimeSource>(
        time: T,
        refresh_list: Vec<CredentialsResult>,
        load_list: Vec<CredentialsResult>,
    ) -> LazyCachingCredentialsProvider {
        let refresh_list = Arc::new(Mutex::new(refresh_list));
        let load_list = Arc::new(Mutex::new(load_list));
        LazyCachingCredentialsProvider::new(
            time,
            Box::new(TokioSleep::new()),
            Arc::new(async_provide_credentials_fn(move || {
                let list = refresh_list.clone();
                let list = load_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    info!("refreshing the credentials to {:?}", next);
                    next
                }
            })),
            DEFAULT_REFRESH_TIMEOUT,
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        )
@@ -272,14 +295,15 @@ mod tests {
    #[test_env_log::test(tokio::test)]
    async fn initial_populate_credentials() {
        let time = TestTime::new(epoch_secs(100));
        let refresh = Arc::new(async_provide_credentials_fn(|| async {
        let loader = Arc::new(async_provide_credentials_fn(|| async {
            info!("refreshing the credentials");
            Ok(credentials(1000))
        }));
        let provider = LazyCachingCredentialsProvider::new(
            time,
            refresh,
            DEFAULT_REFRESH_TIMEOUT,
            Box::new(TokioSleep::new()),
            loader,
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        );
@@ -295,7 +319,7 @@ mod tests {
    }

    #[test_env_log::test(tokio::test)]
    async fn refresh_expired_credentials() {
    async fn reload_expired_credentials() {
        let time = TestTime::new(epoch_secs(100));
        let time_inner = time.time.clone();
        let provider = test_provider(
@@ -318,7 +342,7 @@ mod tests {
    }

    #[test_env_log::test(tokio::test)]
    async fn refresh_failed_error() {
    async fn load_failed_error() {
        let time = TestTime::new(epoch_secs(100));
        let time_inner = time.time.clone();
        let provider = test_provider(
@@ -335,8 +359,9 @@ mod tests {
    }

    #[test_env_log::test]
    fn refresh_retrieve_contention() {
    fn load_contention() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .worker_threads(16)
            .build()
            .unwrap();
@@ -377,4 +402,25 @@ mod tests {
            }
        }
    }

    #[test_env_log::test(tokio::test)]
    async fn load_timeout() {
        let time = TestTime::new(epoch_secs(100));
        let provider = LazyCachingCredentialsProvider::new(
            time,
            Box::new(TokioSleep::new()),
            Arc::new(async_provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(credentials(1000))
            })),
            Duration::from_millis(5),
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        );

        assert!(matches!(
            provider.provide_credentials().await,
            Err(CredentialsError::ProviderTimedOut(_))
        ));
    }
}
Loading