Unverified Commit 11029ed3 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement initial lazy caching credentials provider (#578)

* Implement initial lazy caching credentials provider

* Rename TimeProvider to TimeSource

* Move TimeSource to its own module

* Eliminate Inner layer and add expiry_mut to Credentials

* Move Cache to its own module and fix multithreading issue

* Add comments

* Make refresh_timeout unimplemented

* Combine Provider with LazyCachingCredentialsProvider

* CR feedback
parent f08a60cf
Loading
Loading
Loading
Loading
+8 −4
Original line number Diff line number Diff line
[package]
name = "aws-auth"
version = "0.1.0"
authors = ["Russell Cohen <rcoh@amazon.com>"]
authors = ["AWS Rust SDK Team <aws-sdk-rust@amazon.com>", "Russell Cohen <rcoh@amazon.com>"]
license = "Apache-2.0"
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
smithy-http = { path = "../../../rust-runtime/smithy-http" }
tokio = { version = "1", features = ["sync"] }
tracing = "0.1.25"
zeroize = "1.2.0"

[dev-dependencies]
http = "0.2.3"
tokio = { version = "1.0", features = ["rt", "macros"] }
async-trait = "0.1.50"
env_logger = "*"
http = "0.2.3"
test-env-log = { version = "0.2.7", features = ["trace"] }
tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
+5 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ use zeroize::Zeroizing;
#[derive(Clone)]
pub struct Credentials(Arc<Inner>);

#[derive(Clone)]
struct Inner {
    access_key_id: Zeroizing<String>,
    secret_access_key: Zeroizing<String>,
@@ -89,6 +90,10 @@ impl Credentials {
        self.0.expires_after
    }

    /// Returns a mutable reference to the expiration time of these credentials.
    ///
    /// Access goes through [`Arc::make_mut`], so the inner data is cloned first
    /// if it is currently shared with other clones of this `Credentials`.
    pub fn expiry_mut(&mut self) -> &mut Option<SystemTime> {
        let inner = Arc::make_mut(&mut self.0);
        &mut inner.expires_after
    }

    /// Returns the session token as a string slice, or `None` if there isn't one.
    pub fn session_token(&self) -> Option<&str> {
        let token = &self.0.session_token;
        token.as_deref()
    }
+3 −0
Original line number Diff line number Diff line
@@ -3,7 +3,10 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

mod cache;
pub mod env;
pub mod lazy_caching;
mod time;

use crate::Credentials;
use smithy_http::property_bag::PropertyBag;
+133 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::provider::CredentialsResult;
use crate::Credentials;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OnceCell, RwLock};

/// Cache holding at most one set of credentials.
///
/// Cloning a `Cache` clones the inner `Arc`, so every clone refers to the same cached value.
#[derive(Clone)]
pub(super) struct Cache {
    /// Amount of time before the actual credential expiration time
    /// where credentials are considered expired.
    buffer_time: Duration,
    /// The cached credentials. The cell is empty until credentials have been loaded,
    /// and the whole cell is swapped for a fresh one (under the write lock) to clear the cache.
    value: Arc<RwLock<OnceCell<Credentials>>>,
}

impl Cache {
    /// Creates an empty cache that treats credentials as expired `buffer_time`
    /// before their actual expiration time.
    pub fn new(buffer_time: Duration) -> Cache {
        let value = Arc::new(RwLock::new(OnceCell::new()));
        Cache { buffer_time, value }
    }

    /// Test-only peek at the cached credentials; never triggers a load.
    #[cfg(test)]
    async fn get(&self) -> Option<Credentials> {
        let guard = self.value.read().await;
        guard.get().cloned()
    }

    /// Returns the cached credentials, loading them with the future produced by `f`
    /// when the cache is empty.
    ///
    /// If multiple threads attempt to load at the same time, one of them will win,
    /// and the others will await that thread's result rather than multiple loads occurring.
    /// The function given to acquire a credentials future, `f`, will not be called
    /// if another thread is chosen to load the credentials.
    ///
    /// An error from the load is returned to the caller.
    pub async fn get_or_load<F, Fut>(&self, f: F) -> CredentialsResult
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = CredentialsResult>,
    {
        // The read guard stays alive for the whole load, so the cache cannot be
        // cleared (which needs the write lock) while a load is in flight.
        let cell = self.value.read().await;
        let credentials = cell.get_or_try_init(f).await?;
        Ok(credentials.clone())
    }

    /// If the credentials are expired, clears the cache. Otherwise, yields the current credentials value.
    ///
    /// Returns `Some(credentials)` when the cache holds credentials that are not expired
    /// relative to `now` (taking the buffer time into account). Returns `None` when the
    /// cache is empty or its expired credentials were just cleared; the caller is then
    /// expected to load fresh credentials.
    pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<Credentials> {
        // Short-circuit if the credential is not expired
        if let Some(credentials) = self.value.read().await.get() {
            if !expired(credentials, self.buffer_time, now) {
                return Some(credentials.clone());
            }
        }

        // Acquire a write lock to clear the cache, but then once the lock is acquired,
        // check again what the cache holds: another thread may have cleared or refreshed
        // it while this one was waiting for the write lock.
        let mut lock = self.value.write().await;
        if let Some(credentials) = lock.get() {
            // Credentials that were refreshed by another thread in the meantime are
            // still valid: yield them instead of reporting a cache miss.
            if !expired(credentials, self.buffer_time, now) {
                return Some(credentials.clone());
            }
            // Only expired credentials are ever cleared.
            *lock = OnceCell::new();
        }
        None
    }
}

/// Returns `true` if `credentials` should be treated as expired at `now`, i.e. when
/// `now` is at or past the credential expiration time minus `buffer_time`.
///
/// # Panics
///
/// Panics if the credentials have no expiration time; cached credentials are
/// expected to always carry one.
fn expired(credentials: &Credentials, buffer_time: Duration, now: SystemTime) -> bool {
    let expiration = credentials
        .expiry()
        .expect("Cached credentials don't have an expiration time. This is a bug in aws-auth.");
    // `SystemTime - Duration` panics when the result is not representable (for example
    // with a very large buffer time). In that case the threshold lies before any
    // representable time, so the credentials count as expired.
    match expiration.checked_sub(buffer_time) {
        Some(threshold) => now >= threshold,
        None => true,
    }
}

#[cfg(test)]
mod tests {
    use super::{expired, Cache};
    use crate::Credentials;
    use std::time::{Duration, SystemTime};

    /// Builds test credentials that expire `expired_secs` seconds after the Unix epoch.
    fn credentials(expired_secs: u64) -> Credentials {
        Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
    }

    /// Returns the `SystemTime` that is `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// With an expiry at t=100 and a 10 second buffer, the credentials count as
    /// expired from t=90 onwards (inclusive).
    #[test]
    fn expired_check() {
        let creds = credentials(100);
        assert!(expired(&creds, Duration::from_secs(10), epoch_secs(1000)));
        assert!(expired(&creds, Duration::from_secs(10), epoch_secs(90)));
        assert!(!expired(&creds, Duration::from_secs(10), epoch_secs(10)));
    }

    /// Walks the cache through empty -> loaded -> still valid -> expired and cleared.
    #[test_env_log::test(tokio::test)]
    async fn cache_clears_if_expired_only() {
        let cache = Cache::new(Duration::from_secs(10));
        // An empty cache yields nothing
        assert!(cache
            .yield_or_clear_if_expired(epoch_secs(100))
            .await
            .is_none());

        // Populate the cache with credentials expiring at t=100
        cache
            .get_or_load(|| async { Ok(credentials(100)) })
            .await
            .unwrap();
        assert_eq!(Some(epoch_secs(100)), cache.get().await.unwrap().expiry());

        // It should not clear the credentials if they're not expired
        assert_eq!(
            Some(epoch_secs(100)),
            cache
                .yield_or_clear_if_expired(epoch_secs(10))
                .await
                .unwrap()
                .expiry()
        );

        // It should clear the credentials if they're expired
        assert!(cache
            .yield_or_clear_if_expired(epoch_secs(500))
            .await
            .is_none());
        assert!(cache.get().await.is_none());
    }
}
+380 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

use crate::provider::cache::Cache;
use crate::provider::time::TimeSource;
use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsResult};
use std::sync::Arc;
use std::time::Duration;
use tracing::{trace_span, Instrument};

/// Refresh timeout used when the builder is not given one (5 seconds).
const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
/// Default lifetime (15 minutes) assigned to credentials that come without an expiry.
/// The builder also uses this as the minimum accepted `default_credential_expiration`.
const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(15 * 60);
/// Buffer time used when the builder is not given one (10 seconds).
const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);

// TODO: Implement async runtime-agnostic timeouts
// TODO: Add catch_unwind() to handle panics
// TODO: Update doc comment below once catch_unwind and timeouts are implemented
// TODO: Update warning not to use this in the STS example once it's prod ready

/// `LazyCachingCredentialsProvider` implements [`AsyncProvideCredentials`] by caching
/// credentials that it loads by calling a user-provided [`AsyncProvideCredentials`] implementation.
///
/// For example, you can provide an [`AsyncProvideCredentials`] implementation that calls
/// AWS STS's AssumeRole operation to get temporary credentials, and `LazyCachingCredentialsProvider`
/// will cache those credentials until they expire.
///
/// # Note
///
/// This is __NOT__ production ready yet. Timeouts and panic safety have not been implemented yet.
pub struct LazyCachingCredentialsProvider {
    /// Source of the current time used for expiration checks.
    time: Box<dyn TimeSource>,
    /// Cache holding the most recently loaded credentials.
    cache: Cache,
    /// Provider that is called to load fresh credentials when the cache has none to yield.
    refresh: Arc<dyn AsyncProvideCredentials>,
    /// Timeout for the refresh provider. Stored but not applied yet (see the TODOs above).
    _refresh_timeout: Duration,
    /// Lifetime assigned to loaded credentials that don't have an expiration time.
    default_credential_expiration: Duration,
}

impl LazyCachingCredentialsProvider {
    fn new(
        time: impl TimeSource,
        refresh: Arc<dyn AsyncProvideCredentials>,
        refresh_timeout: Duration,
        default_credential_expiration: Duration,
        buffer_time: Duration,
    ) -> Self {
        LazyCachingCredentialsProvider {
            time: Box::new(time),
            cache: Cache::new(buffer_time),
            refresh,
            _refresh_timeout: refresh_timeout,
            default_credential_expiration,
        }
    }

    /// Starts a `Builder`, with no options set yet, for constructing a
    /// `LazyCachingCredentialsProvider`.
    pub fn builder() -> builder::Builder {
        builder::Builder::default()
    }
}

impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
    /// Yields the cached credentials if they are still valid; otherwise loads fresh
    /// credentials from the refresh provider and caches them.
    ///
    /// Credentials returned by the refresh provider without an expiration time are
    /// given one of `now + default_credential_expiration` before being cached, where
    /// `now` is the time at which this method was called. Errors from the refresh
    /// provider are returned to the caller.
    fn provide_credentials<'a>(&'a self) -> BoxFuture<'a, CredentialsResult>
    where
        Self: 'a,
    {
        let now = self.time.now();
        let refresh = self.refresh.clone();
        let cache = self.cache.clone();
        let default_credential_expiration = self.default_credential_expiration;

        Box::pin(async move {
            // Attempt to get cached credentials, or clear the cache if they're expired
            if let Some(credentials) = cache.yield_or_clear_if_expired(now).await {
                Ok(credentials)
            } else {
                // If we didn't get credentials from the cache, then we need to try and refresh.
                // There may be other threads also refreshing simultaneously, but this is OK:
                // the refresh provider is only invoked inside the loader below, and the
                // cache will only run one of the loaders.
                let span = trace_span!("lazy_refresh_credentials");
                cache
                    .get_or_load(|| {
                        async move {
                            // Call the provider here rather than before `get_or_load`, so
                            // that it is not invoked at all when another thread wins the load.
                            let mut credentials = refresh.provide_credentials().await?;
                            // If the credentials don't have an expiration time, then create a default one
                            if credentials.expiry().is_none() {
                                *credentials.expiry_mut() =
                                    Some(now + default_credential_expiration);
                            }
                            Ok(credentials)
                        }
                        // Only instrument the actual refreshing future so that no span
                        // is opened if the cache decides not to execute it.
                        .instrument(span)
                    })
                    .await
            }
        })
    }
}

pub mod builder {
    use crate::provider::lazy_caching::{
        LazyCachingCredentialsProvider, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION,
        DEFAULT_REFRESH_TIMEOUT,
    };
    use crate::provider::time::SystemTimeSource;
    use crate::provider::AsyncProvideCredentials;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builder for constructing a [`LazyCachingCredentialsProvider`].
    ///
    /// # Example
    ///
    /// ```
    /// use aws_auth::Credentials;
    /// use aws_auth::provider::async_provide_credentials_fn;
    /// use aws_auth::provider::lazy_caching::LazyCachingCredentialsProvider;
    /// use std::sync::Arc;
    /// use std::time::Duration;
    ///
    /// let provider = LazyCachingCredentialsProvider::builder()
    ///     .refresh(async_provide_credentials_fn(|| async {
    ///         // An async process to retrieve credentials would go here:
    ///         Ok(Credentials::from_keys("example", "example", None))
    ///     }))
    ///     .build();
    /// ```
    #[derive(Default)]
    pub struct Builder {
        /// Provider used to load credentials. Required: `build` panics if it is unset.
        refresh: Option<Arc<dyn AsyncProvideCredentials>>,
        /// Timeout for the refresh provider; falls back to `DEFAULT_REFRESH_TIMEOUT`.
        refresh_timeout: Option<Duration>,
        /// Time before the actual expiration at which credentials are considered expired;
        /// falls back to `DEFAULT_BUFFER_TIME`.
        buffer_time: Option<Duration>,
        /// Lifetime for credentials without an expiry; falls back to
        /// `DEFAULT_CREDENTIAL_EXPIRATION`.
        default_credential_expiration: Option<Duration>,
    }

    impl Builder {
        pub fn new() -> Self {
            Default::default()
        }

        /// An implementation of [`AsyncProvideCredentials`] that will be used to refresh
        /// the cached credentials once they're expired.
        pub fn refresh(mut self, refresh: impl AsyncProvideCredentials + 'static) -> Self {
            self.refresh = Some(Arc::new(refresh));
            self
        }

        /// (Optional) Timeout for the given [`AsyncProvideCredentials`] implementation.
        /// Defaults to 5 seconds.
        ///
        /// # Panics
        ///
        /// Always panics at the moment: refresh timeouts have not been implemented yet,
        /// so this setter must not be called.
        pub fn refresh_timeout(mut self, timeout: Duration) -> Self {
            self.refresh_timeout = Some(timeout);
            unimplemented!("refresh_timeout hasn't been implemented yet")
        }

        /// (Optional) Amount of time before the actual credential expiration time
        /// where credentials are considered expired. For example, if credentials are expiring
        /// in 15 minutes, and the buffer time is 10 seconds, then any requests made after
        /// 14 minutes and 50 seconds will load new credentials. Defaults to 10 seconds.
        pub fn buffer_time(mut self, buffer_time: Duration) -> Self {
            self.buffer_time = Some(buffer_time);
            self
        }

        /// (Optional) Default expiration time to set on credentials if they don't
        /// have an expiration time. This is only used if the given [`AsyncProvideCredentials`]
        /// returns [`Credentials`](crate::Credentials) that don't have their `expiry` set.
        /// This must be at least 15 minutes.
        pub fn default_credential_expiration(mut self, duration: Duration) -> Self {
            self.default_credential_expiration = Some(duration);
            self
        }

        /// Creates the [`LazyCachingCredentialsProvider`].
        ///
        /// Options that were not set fall back to their defaults: a 5 second refresh
        /// timeout, a 10 second buffer time, and a 15 minute default credential expiration.
        ///
        /// # Panics
        ///
        /// Panics if no refresh provider was set, or if the default credential
        /// expiration is less than 15 minutes.
        pub fn build(self) -> LazyCachingCredentialsProvider {
            let default_credential_expiration = self
                .default_credential_expiration
                .unwrap_or(DEFAULT_CREDENTIAL_EXPIRATION);
            assert!(
                default_credential_expiration >= DEFAULT_CREDENTIAL_EXPIRATION,
                "default_credential_expiration must be at least 15 minutes"
            );
            // Argument order must match `LazyCachingCredentialsProvider::new`:
            // refresh timeout, then default credential expiration, then buffer time.
            LazyCachingCredentialsProvider::new(
                SystemTimeSource,
                self.refresh.expect("refresh provider is required"),
                self.refresh_timeout.unwrap_or(DEFAULT_REFRESH_TIMEOUT),
                default_credential_expiration,
                self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::provider::lazy_caching::{
        LazyCachingCredentialsProvider, TimeSource, DEFAULT_BUFFER_TIME,
        DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_REFRESH_TIMEOUT,
    };
    use crate::provider::{
        async_provide_credentials_fn, AsyncProvideCredentials, CredentialsError, CredentialsResult,
    };
    use crate::Credentials;
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, SystemTime};
    use tracing::info;

    /// Manually controlled clock for tests. Clones share the same underlying time value.
    #[derive(Clone)]
    struct TestTime {
        time: Arc<Mutex<SystemTime>>,
    }

    impl TestTime {
        /// Creates a clock that reports `time` until it is changed with `set`.
        fn new(time: SystemTime) -> Self {
            TestTime {
                time: Arc::new(Mutex::new(time)),
            }
        }

        /// Overwrites the shared time value. Takes the inner handle so the time can still
        /// be changed after the `TestTime` itself has been moved into a provider.
        fn set(inner: &Arc<Mutex<SystemTime>>, time: SystemTime) {
            *inner.lock().unwrap() = time;
        }
    }

    impl TimeSource for TestTime {
        fn now(&self) -> SystemTime {
            *self.time.lock().unwrap()
        }
    }

    /// Builds a provider whose refresh function hands out the entries of `refresh_list`
    /// in order, one per refresh. It panics if it is refreshed more often than there
    /// are entries in the list.
    fn test_provider<T: TimeSource>(
        time: T,
        refresh_list: Vec<CredentialsResult>,
    ) -> LazyCachingCredentialsProvider {
        let refresh_list = Arc::new(Mutex::new(refresh_list));
        LazyCachingCredentialsProvider::new(
            time,
            Arc::new(async_provide_credentials_fn(move || {
                let list = refresh_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    info!("refreshing the credentials to {:?}", next);
                    next
                }
            })),
            DEFAULT_REFRESH_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        )
    }

    /// Returns the `SystemTime` that is `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// Builds test credentials that expire `expired_secs` seconds after the Unix epoch.
    fn credentials(expired_secs: u64) -> Credentials {
        Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
    }

    /// Asserts that `provider` yields credentials expiring at `expired_secs`.
    async fn expect_creds(expired_secs: u64, provider: &LazyCachingCredentialsProvider) {
        let creds = provider
            .provide_credentials()
            .await
            .expect("expected credentials");
        assert_eq!(Some(epoch_secs(expired_secs)), creds.expiry());
    }

    /// The first request against an empty cache loads credentials from the refresh provider.
    #[test_env_log::test(tokio::test)]
    async fn initial_populate_credentials() {
        let time = TestTime::new(epoch_secs(100));
        let refresh = Arc::new(async_provide_credentials_fn(|| async {
            info!("refreshing the credentials");
            Ok(credentials(1000))
        }));
        let provider = LazyCachingCredentialsProvider::new(
            time,
            refresh,
            DEFAULT_REFRESH_TIMEOUT,
            DEFAULT_CREDENTIAL_EXPIRATION,
            DEFAULT_BUFFER_TIME,
        );
        assert_eq!(
            epoch_secs(1000),
            provider
                .provide_credentials()
                .await
                .unwrap()
                .expiry()
                .unwrap()
        );
    }

    /// Credentials are served from the cache until the clock passes their expiry,
    /// at which point the next entry from the refresh list is loaded.
    #[test_env_log::test(tokio::test)]
    async fn refresh_expired_credentials() {
        let time = TestTime::new(epoch_secs(100));
        let time_inner = time.time.clone();
        let provider = test_provider(
            time,
            vec![
                Ok(credentials(1000)),
                Ok(credentials(2000)),
                Ok(credentials(3000)),
            ],
        );

        expect_creds(1000, &provider).await;
        expect_creds(1000, &provider).await;
        TestTime::set(&time_inner, epoch_secs(1500));
        expect_creds(2000, &provider).await;
        expect_creds(2000, &provider).await;
        TestTime::set(&time_inner, epoch_secs(2500));
        expect_creds(3000, &provider).await;
        expect_creds(3000, &provider).await;
    }

    /// An error from the refresh provider is surfaced to the caller.
    #[test_env_log::test(tokio::test)]
    async fn refresh_failed_error() {
        let time = TestTime::new(epoch_secs(100));
        let time_inner = time.time.clone();
        let provider = test_provider(
            time,
            vec![
                Ok(credentials(1000)),
                Err(CredentialsError::CredentialsNotLoaded),
            ],
        );

        expect_creds(1000, &provider).await;
        TestTime::set(&time_inner, epoch_secs(1500));
        assert!(provider.provide_credentials().await.is_err());
    }

    /// Runs four rounds of 50 tasks on a 16-worker runtime. Each task sets the shared
    /// clock and then requests credentials, asserting that what it receives has not
    /// expired relative to the time it set.
    #[test_env_log::test]
    fn refresh_retrieve_contention() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(16)
            .build()
            .unwrap();

        let time = TestTime::new(epoch_secs(0));
        let time_inner = time.time.clone();
        let provider = Arc::new(test_provider(
            time,
            vec![
                Ok(credentials(500)),
                Ok(credentials(1500)),
                Ok(credentials(2500)),
                Ok(credentials(3500)),
                Ok(credentials(4500)),
            ],
        ));

        for i in 0..4 {
            let mut tasks = Vec::new();
            for j in 0..50 {
                let provider = provider.clone();
                let time_inner = time_inner.clone();
                tasks.push(rt.spawn(async move {
                    let now = epoch_secs(i * 1000 + (4 * j));
                    TestTime::set(&time_inner, now);

                    let creds = provider.provide_credentials().await.unwrap();
                    assert!(
                        creds.expiry().unwrap() >= now,
                        "{:?} >= {:?}",
                        creds.expiry(),
                        now
                    );
                }));
            }
            // Wait for every task in this round; a failed assertion surfaces as a join error.
            for task in tasks {
                rt.block_on(task).unwrap();
            }
        }
    }
}
Loading