diff --git a/aws/rust-runtime/aws-auth/Cargo.toml b/aws/rust-runtime/aws-auth/Cargo.toml
index d04d7fb71882e2a8ca199210b30ae7d86536e3a5..da29e6d42d96260e3242f68e7b9b826767d3b7f4 100644
--- a/aws/rust-runtime/aws-auth/Cargo.toml
+++ b/aws/rust-runtime/aws-auth/Cargo.toml
@@ -1,17 +1,21 @@
 [package]
 name = "aws-auth"
 version = "0.1.0"
-authors = ["Russell Cohen "]
+authors = ["AWS Rust SDK Team ", "Russell Cohen "]
 license = "Apache-2.0"
 edition = "2018"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
 smithy-http = { path = "../../../rust-runtime/smithy-http" }
+tokio = { version = "1", features = ["sync"] }
+tracing = "0.1.25"
 zeroize = "1.2.0"
 
 [dev-dependencies]
-http = "0.2.3"
-tokio = { version = "1.0", features = ["rt", "macros"] }
 async-trait = "0.1.50"
+env_logger = "*"
+http = "0.2.3"
+test-env-log = { version = "0.2.7", features = ["trace"] }
+tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "test-util"] }
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/rust-runtime/aws-auth/src/credentials.rs b/aws/rust-runtime/aws-auth/src/credentials.rs
index 21ff726323e63dda6423095fc47574f95c7dc191..00b935cd3d44442552b0ac124d8b5ca0c023ea1d 100644
--- a/aws/rust-runtime/aws-auth/src/credentials.rs
+++ b/aws/rust-runtime/aws-auth/src/credentials.rs
@@ -19,6 +19,7 @@ use zeroize::Zeroizing;
 #[derive(Clone)]
 pub struct Credentials(Arc<Inner>);
 
+#[derive(Clone)]
 struct Inner {
     access_key_id: Zeroizing<String>,
     secret_access_key: Zeroizing<String>,
@@ -89,6 +90,12 @@ impl Credentials {
         self.0.expires_after
     }
 
+    /// Mutable access to the expiration time. `Arc::make_mut` clones the inner
+    /// data first when other `Credentials` clones still share it.
+    pub fn expiry_mut(&mut self) -> &mut Option<SystemTime> {
+        &mut Arc::make_mut(&mut self.0).expires_after
+    }
+
     pub fn session_token(&self) -> Option<&str> {
         self.0.session_token.as_deref()
     }
diff --git a/aws/rust-runtime/aws-auth/src/provider.rs b/aws/rust-runtime/aws-auth/src/provider.rs
index 1adb80de9ae09df1addb3c0af17f5870689e449b..7741d714f85db776b8aa2b6189377ef7310bd41c 100644
--- a/aws/rust-runtime/aws-auth/src/provider.rs
+++ b/aws/rust-runtime/aws-auth/src/provider.rs
@@ -3,7 +3,10 @@
  * SPDX-License-Identifier: Apache-2.0.
  */
 
+mod cache;
 pub mod env;
+pub mod lazy_caching;
+mod time;
 
 use crate::Credentials;
 use smithy_http::property_bag::PropertyBag;
diff --git a/aws/rust-runtime/aws-auth/src/provider/cache.rs b/aws/rust-runtime/aws-auth/src/provider/cache.rs
new file mode 100644
index 0000000000000000000000000000000000000000..93e8140716c9720c61858d4824b29f149890ee0d
--- /dev/null
+++ b/aws/rust-runtime/aws-auth/src/provider/cache.rs
@@ -0,0 +1,133 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+use crate::provider::CredentialsResult;
+use crate::Credentials;
+use std::future::Future;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime};
+use tokio::sync::{OnceCell, RwLock};
+
+#[derive(Clone)]
+pub(super) struct Cache {
+    /// Amount of time before the actual credential expiration time
+    /// where credentials are considered expired.
+    buffer_time: Duration,
+    value: Arc<RwLock<OnceCell<Credentials>>>,
+}
+
+impl Cache {
+    pub fn new(buffer_time: Duration) -> Cache {
+        Cache {
+            buffer_time,
+            value: Arc::new(RwLock::new(OnceCell::new())),
+        }
+    }
+
+    #[cfg(test)]
+    async fn get(&self) -> Option<Credentials> {
+        self.value.read().await.get().cloned()
+    }
+
+    /// Attempts to refresh the cached credentials with the given async future.
+    /// If multiple threads attempt to refresh at the same time, one of them will win,
+    /// and the others will await that thread's result rather than multiple refreshes occurring.
+    /// The function given to acquire a credentials future, `f`, will not be called
+    /// if another thread is chosen to load the credentials.
+    pub async fn get_or_load<F, Fut>(&self, f: F) -> CredentialsResult
+    where
+        F: FnOnce() -> Fut,
+        Fut: Future<Output = CredentialsResult>,
+    {
+        let lock = self.value.read().await;
+        let future = lock.get_or_try_init(f);
+        future.await.map(|credentials| credentials.clone())
+    }
+
+    /// If the credentials are expired, clears the cache. Otherwise, yields the current credentials value.
+    pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<Credentials> {
+        // Short-circuit if the credential is not expired
+        if let Some(credentials) = self.value.read().await.get() {
+            if !expired(credentials, self.buffer_time, now) {
+                return Some(credentials.clone());
+            }
+        }
+
+        // Acquire a write lock to clear the cache, but then once the lock is acquired,
+        // check again that the credential is not already cleared. If it has been cleared,
+        // then another thread is refreshing the cache by the time the write lock was acquired.
+        let mut lock = self.value.write().await;
+        if let Some(credentials) = lock.get() {
+            // Also check that we're clearing the expired credentials and not credentials
+            // that have been refreshed by another thread.
+            if expired(credentials, self.buffer_time, now) {
+                *lock = OnceCell::new();
+            }
+        }
+        None
+    }
+}
+
+fn expired(credentials: &Credentials, buffer_time: Duration, now: SystemTime) -> bool {
+    credentials
+        .expiry()
+        .map(|expiration| now >= (expiration - buffer_time))
+        .expect("Cached credentials don't have an expiration time. 
This is a bug in aws-auth.")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{expired, Cache};
+    use crate::Credentials;
+    use std::time::{Duration, SystemTime};
+
+    /// Builds test credentials whose expiry is `expired_secs` seconds after the Unix epoch.
+    fn credentials(expired_secs: u64) -> Credentials {
+        Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
+    }
+
+    /// Returns the `SystemTime` that is `secs` seconds after the Unix epoch.
+    fn epoch_secs(secs: u64) -> SystemTime {
+        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
+    }
+
+    #[test]
+    fn expired_check() {
+        // With a 10s buffer, credentials expiring at t=100 count as expired from t=90 onward.
+        let creds = credentials(100);
+        assert!(expired(&creds, Duration::from_secs(10), epoch_secs(1000)));
+        assert!(expired(&creds, Duration::from_secs(10), epoch_secs(90)));
+        assert!(!expired(&creds, Duration::from_secs(10), epoch_secs(10)));
+    }
+
+    #[test_env_log::test(tokio::test)]
+    async fn cache_clears_if_expired_only() {
+        // A freshly created cache has nothing to yield
+        let cache = Cache::new(Duration::from_secs(10));
+        assert!(cache
+            .yield_or_clear_if_expired(epoch_secs(100))
+            .await
+            .is_none());
+
+        cache
+            .get_or_load(|| async { Ok(credentials(100)) })
+            .await
+            .unwrap();
+        assert_eq!(Some(epoch_secs(100)), cache.get().await.unwrap().expiry());
+
+        // It should not clear the credentials if they're not expired
+        assert_eq!(
+            Some(epoch_secs(100)),
+            cache
+                .yield_or_clear_if_expired(epoch_secs(10))
+                .await
+                .unwrap()
+                .expiry()
+        );
+
+        // It should clear the credentials if they're expired
+        assert!(cache
+            .yield_or_clear_if_expired(epoch_secs(500))
+            .await
+            .is_none());
+        assert!(cache.get().await.is_none());
+    }
+}
diff --git a/aws/rust-runtime/aws-auth/src/provider/lazy_caching.rs b/aws/rust-runtime/aws-auth/src/provider/lazy_caching.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4bcbd65ee4e4dedefc6b7e05750a58ab9ea37d4b
--- /dev/null
+++ b/aws/rust-runtime/aws-auth/src/provider/lazy_caching.rs
@@ -0,0 +1,380 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+use crate::provider::cache::Cache;
+use crate::provider::time::TimeSource;
+use crate::provider::{AsyncProvideCredentials, BoxFuture, CredentialsResult};
+use std::sync::Arc;
+use std::time::Duration;
+use tracing::{trace_span, Instrument};
+
+const DEFAULT_REFRESH_TIMEOUT: Duration = Duration::from_secs(5);
+const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(15 * 60);
+const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);
+
+// TODO: Implement async runtime-agnostic timeouts
+// TODO: Add catch_unwind() to handle panics
+// TODO: Update doc comment below once catch_unwind and timeouts are implemented
+// TODO: Update warning not to use this in the STS example once it's prod ready
+
+/// `LazyCachingCredentialsProvider` implements [`AsyncProvideCredentials`] by caching
+/// credentials that it loads by calling a user-provided [`AsyncProvideCredentials`] implementation.
+///
+/// For example, you can provide an [`AsyncProvideCredentials`] implementation that calls
+/// AWS STS's AssumeRole operation to get temporary credentials, and `LazyCachingCredentialsProvider`
+/// will cache those credentials until they expire.
+///
+/// # Note
+///
+/// This is __NOT__ production ready yet. Timeouts and panic safety have not been implemented yet.
+pub struct LazyCachingCredentialsProvider {
+    time: Box<dyn TimeSource>,
+    cache: Cache,
+    refresh: Arc<dyn AsyncProvideCredentials>,
+    _refresh_timeout: Duration,
+    default_credential_expiration: Duration,
+}
+
+impl LazyCachingCredentialsProvider {
+    /// Assembles the provider. The three `Duration` arguments are positional and easy to
+    /// mix up: refresh timeout, then default credential expiration, then buffer time.
+    fn new(
+        time: impl TimeSource,
+        refresh: Arc<dyn AsyncProvideCredentials>,
+        refresh_timeout: Duration,
+        default_credential_expiration: Duration,
+        buffer_time: Duration,
+    ) -> Self {
+        LazyCachingCredentialsProvider {
+            time: Box::new(time),
+            cache: Cache::new(buffer_time),
+            refresh,
+            _refresh_timeout: refresh_timeout,
+            default_credential_expiration,
+        }
+    }
+
+    /// Returns a new `Builder` that can be used to construct the `LazyCachingCredentialsProvider`.
+    pub fn builder() -> builder::Builder {
+        builder::Builder::new()
+    }
+}
+
+impl AsyncProvideCredentials for LazyCachingCredentialsProvider {
+    fn provide_credentials<'a>(&'a self) -> BoxFuture<'a, CredentialsResult>
+    where
+        Self: 'a,
+    {
+        // The time is read once, when the future is created; the same `now` is used for the
+        // expiry check and for computing a default expiration below.
+        let now = self.time.now();
+        let refresh = self.refresh.clone();
+        let cache = self.cache.clone();
+        let default_credential_expiration = self.default_credential_expiration;
+
+        Box::pin(async move {
+            // Attempt to get cached credentials, or clear the cache if they're expired
+            if let Some(credentials) = cache.yield_or_clear_if_expired(now).await {
+                Ok(credentials)
+            } else {
+                // If we didn't get credentials from the cache, then we need to try and refresh.
+                // There may be other threads also refreshing simultaneously, but this is OK
+                // since the futures are not eagerly executed, and the cache will only run one
+                // of them.
+                let span = trace_span!("lazy_refresh_credentials");
+                let future = refresh.provide_credentials();
+                cache
+                    .get_or_load(|| {
+                        async move {
+                            let mut credentials = future.await?;
+                            // If the credentials don't have an expiration time, then create a default one
+                            if credentials.expiry().is_none() {
+                                *credentials.expiry_mut() =
+                                    Some(now + default_credential_expiration);
+                            }
+                            Ok(credentials)
+                        }
+                        // Only instrument the actual refreshing future so that no span
+                        // is opened if the cache decides not to execute it.
+                        .instrument(span)
+                    })
+                    .await
+            }
+        })
+    }
+}
+
+pub mod builder {
+    use crate::provider::lazy_caching::{
+        LazyCachingCredentialsProvider, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION,
+        DEFAULT_REFRESH_TIMEOUT,
+    };
+    use crate::provider::time::SystemTimeSource;
+    use crate::provider::AsyncProvideCredentials;
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    /// Builder for constructing a [`LazyCachingCredentialsProvider`].
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use aws_auth::Credentials;
+    /// use aws_auth::provider::async_provide_credentials_fn;
+    /// use aws_auth::provider::lazy_caching::LazyCachingCredentialsProvider;
+    /// use std::sync::Arc;
+    /// use std::time::Duration;
+    ///
+    /// let provider = LazyCachingCredentialsProvider::builder()
+    ///     .refresh(async_provide_credentials_fn(|| async {
+    ///         // An async process to retrieve credentials would go here:
+    ///         Ok(Credentials::from_keys("example", "example", None))
+    ///     }))
+    ///     .build();
+    /// ```
+    #[derive(Default)]
+    pub struct Builder {
+        refresh: Option<Arc<dyn AsyncProvideCredentials>>,
+        refresh_timeout: Option<Duration>,
+        buffer_time: Option<Duration>,
+        default_credential_expiration: Option<Duration>,
+    }
+
+    impl Builder {
+        /// Creates a builder with no refresh provider and no overrides set.
+        pub fn new() -> Self {
+            Default::default()
+        }
+
+        /// An implementation of [`AsyncProvideCredentials`] that will be used to refresh
+        /// the cached credentials once they're expired.
+        pub fn refresh(mut self, refresh: impl AsyncProvideCredentials + 'static) -> Self {
+            self.refresh = Some(Arc::new(refresh));
+            self
+        }
+
+        /// (Optional) Timeout for the given [`AsyncProvideCredentials`] implementation.
+        /// Defaults to 5 seconds.
+        ///
+        /// # Panics
+        ///
+        /// Always panics for now: timeouts are not implemented yet, so this ends in `unimplemented!`.
+        pub fn refresh_timeout(mut self, timeout: Duration) -> Self {
+            self.refresh_timeout = Some(timeout);
+            unimplemented!("refresh_timeout hasn't been implemented yet")
+        }
+
+        /// (Optional) Amount of time before the actual credential expiration time
+        /// where credentials are considered expired. For example, if credentials are expiring
+        /// in 15 minutes, and the buffer time is 10 seconds, then any requests made after
+        /// 14 minutes and 50 seconds will load new credentials. Defaults to 10 seconds.
+        pub fn buffer_time(mut self, buffer_time: Duration) -> Self {
+            self.buffer_time = Some(buffer_time);
+            self
+        }
+
+        /// (Optional) Default expiration time to set on credentials if they don't
+        /// have an expiration time. 
This is only used if the given [`AsyncProvideCredentials`]
+        /// returns [`Credentials`](crate::Credentials) that don't have their `expiry` set.
+        /// This must be at least 15 minutes.
+        pub fn default_credential_expiration(mut self, duration: Duration) -> Self {
+            self.default_credential_expiration = Some(duration);
+            self
+        }
+
+        /// Creates the [`LazyCachingCredentialsProvider`].
+        ///
+        /// # Panics
+        ///
+        /// Panics if no refresh provider was set, or if the configured default credential
+        /// expiration is shorter than 15 minutes.
+        pub fn build(self) -> LazyCachingCredentialsProvider {
+            let default_credential_expiration = self
+                .default_credential_expiration
+                .unwrap_or(DEFAULT_CREDENTIAL_EXPIRATION);
+            assert!(
+                default_credential_expiration >= DEFAULT_CREDENTIAL_EXPIRATION,
+                "default_credential_expiration must be at least 15 minutes"
+            );
+            // Argument order must follow `LazyCachingCredentialsProvider::new`:
+            // refresh timeout, default credential expiration, then buffer time.
+            LazyCachingCredentialsProvider::new(
+                SystemTimeSource,
+                self.refresh.expect("refresh provider is required"),
+                self.refresh_timeout.unwrap_or(DEFAULT_REFRESH_TIMEOUT),
+                default_credential_expiration,
+                self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
+            )
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::provider::lazy_caching::{
+        LazyCachingCredentialsProvider, TimeSource, DEFAULT_BUFFER_TIME,
+        DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_REFRESH_TIMEOUT,
+    };
+    use crate::provider::{
+        async_provide_credentials_fn, AsyncProvideCredentials, CredentialsError, CredentialsResult,
+    };
+    use crate::Credentials;
+    use std::sync::{Arc, Mutex};
+    use std::time::{Duration, SystemTime};
+    use tracing::info;
+
+    #[derive(Clone)]
+    struct TestTime {
+        time: Arc<Mutex<SystemTime>>,
+    }
+
+    impl TestTime {
+        fn new(time: SystemTime) -> Self {
+            TestTime {
+                time: Arc::new(Mutex::new(time)),
+            }
+        }
+
+        fn set(inner: &Arc<Mutex<SystemTime>>, time: SystemTime) {
+            *inner.lock().unwrap() = time;
+        }
+    }
+
+    impl TimeSource for TestTime {
+        fn now(&self) -> SystemTime {
+            *self.time.lock().unwrap()
+        }
+    }
+
+    fn test_provider<T: TimeSource>(
+        time: T,
+        refresh_list: Vec<CredentialsResult>,
+    ) -> LazyCachingCredentialsProvider {
+        let refresh_list = Arc::new(Mutex::new(refresh_list));
+        LazyCachingCredentialsProvider::new(
+            time,
+            
Arc::new(async_provide_credentials_fn(move || {
+                // Each refresh hands back the next scripted result from the front of the
+                // list; `remove(0)` panics once the list has run out.
+                let list = refresh_list.clone();
+                async move {
+                    let next = list.lock().unwrap().remove(0);
+                    info!("refreshing the credentials to {:?}", next);
+                    next
+                }
+            })),
+            DEFAULT_REFRESH_TIMEOUT,
+            DEFAULT_CREDENTIAL_EXPIRATION,
+            DEFAULT_BUFFER_TIME,
+        )
+    }
+
+    /// Returns the `SystemTime` that is `secs` seconds after the Unix epoch.
+    fn epoch_secs(secs: u64) -> SystemTime {
+        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
+    }
+
+    /// Builds test credentials whose expiry is `expired_secs` seconds after the Unix epoch.
+    fn credentials(expired_secs: u64) -> Credentials {
+        Credentials::new("test", "test", None, Some(epoch_secs(expired_secs)), "test")
+    }
+
+    /// Asserts that the provider currently yields credentials expiring at `expired_secs`.
+    async fn expect_creds(expired_secs: u64, provider: &LazyCachingCredentialsProvider) {
+        let creds = provider
+            .provide_credentials()
+            .await
+            .expect("expected credentials");
+        assert_eq!(Some(epoch_secs(expired_secs)), creds.expiry());
+    }
+
+    #[test_env_log::test(tokio::test)]
+    async fn initial_populate_credentials() {
+        let time = TestTime::new(epoch_secs(100));
+        let refresh = Arc::new(async_provide_credentials_fn(|| async {
+            info!("refreshing the credentials");
+            Ok(credentials(1000))
+        }));
+        let provider = LazyCachingCredentialsProvider::new(
+            time,
+            refresh,
+            DEFAULT_REFRESH_TIMEOUT,
+            DEFAULT_CREDENTIAL_EXPIRATION,
+            DEFAULT_BUFFER_TIME,
+        );
+        assert_eq!(
+            epoch_secs(1000),
+            provider
+                .provide_credentials()
+                .await
+                .unwrap()
+                .expiry()
+                .unwrap()
+        );
+    }
+
+    #[test_env_log::test(tokio::test)]
+    async fn refresh_expired_credentials() {
+        let time = TestTime::new(epoch_secs(100));
+        let time_inner = time.time.clone();
+        let provider = test_provider(
+            time,
+            vec![
+                Ok(credentials(1000)),
+                Ok(credentials(2000)),
+                Ok(credentials(3000)),
+            ],
+        );
+
+        // Repeated calls at the same time must return the same credentials; moving the
+        // clock past the current expiry must produce the next scripted credentials.
+        expect_creds(1000, &provider).await;
+        expect_creds(1000, &provider).await;
+        TestTime::set(&time_inner, epoch_secs(1500));
+        expect_creds(2000, &provider).await;
+        expect_creds(2000, &provider).await;
+        TestTime::set(&time_inner, epoch_secs(2500));
+        expect_creds(3000, &provider).await;
+        expect_creds(3000, &provider).await;
+    }
+
+    
#[test_env_log::test(tokio::test)]
+    async fn refresh_failed_error() {
+        let time = TestTime::new(epoch_secs(100));
+        let time_inner = time.time.clone();
+        let provider = test_provider(
+            time,
+            vec![
+                Ok(credentials(1000)),
+                Err(CredentialsError::CredentialsNotLoaded),
+            ],
+        );
+
+        // The second scripted refresh result is an error; once the first credentials
+        // have expired, that error must be returned to the caller.
+        expect_creds(1000, &provider).await;
+        TestTime::set(&time_inner, epoch_secs(1500));
+        assert!(provider.provide_credentials().await.is_err());
+    }
+
+    #[test_env_log::test]
+    fn refresh_retrieve_contention() {
+        let rt = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(16)
+            .build()
+            .unwrap();
+
+        let time = TestTime::new(epoch_secs(0));
+        let time_inner = time.time.clone();
+        let provider = Arc::new(test_provider(
+            time,
+            vec![
+                Ok(credentials(500)),
+                Ok(credentials(1500)),
+                Ok(credentials(2500)),
+                Ok(credentials(3500)),
+                Ok(credentials(4500)),
+            ],
+        ));
+
+        // Four rounds of 50 spawned tasks. Task `j` of round `i` sets the shared clock to
+        // `i * 1000 + 4 * j` seconds and then requests credentials; whatever it gets back
+        // must not expire before that task's own `now`.
+        for i in 0..4 {
+            let mut tasks = Vec::new();
+            for j in 0..50 {
+                let provider = provider.clone();
+                let time_inner = time_inner.clone();
+                tasks.push(rt.spawn(async move {
+                    let now = epoch_secs(i * 1000 + (4 * j));
+                    TestTime::set(&time_inner, now);
+
+                    let creds = provider.provide_credentials().await.unwrap();
+                    assert!(
+                        creds.expiry().unwrap() >= now,
+                        "{:?} >= {:?}",
+                        creds.expiry(),
+                        now
+                    );
+                }));
+            }
+            // Wait for every task of this round before starting the next one
+            for task in tasks {
+                rt.block_on(task).unwrap();
+            }
+        }
+    }
+}
diff --git a/aws/rust-runtime/aws-auth/src/provider/time.rs b/aws/rust-runtime/aws-auth/src/provider/time.rs
new file mode 100644
index 0000000000000000000000000000000000000000..6ca7e9ec29b836a8abfc2f9820c70405b61cd23a
--- /dev/null
+++ b/aws/rust-runtime/aws-auth/src/provider/time.rs
@@ -0,0 +1,20 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+use std::time::SystemTime;
+
+/// Allows us to abstract time for tests.
+pub(super) trait TimeSource: Send + Sync + 'static {
+    /// Returns the current time as seen by this source.
+    fn now(&self) -> SystemTime;
+}
+
+/// [`TimeSource`] that reads the real system clock via [`SystemTime::now`].
+#[derive(Copy, Clone)]
+pub(super) struct SystemTimeSource;
+
+impl TimeSource for SystemTimeSource {
+    fn now(&self) -> SystemTime {
+        SystemTime::now()
+    }
+}
diff --git a/aws/sdk/examples/sts/src/bin/credentials-provider.rs b/aws/sdk/examples/sts/src/bin/credentials-provider.rs
index 6558c8a879b0eae28704902c4f9f972bbed1198b..58a7698e648fee4fbc0fd5a7e17443ae05086a8f 100644
--- a/aws/sdk/examples/sts/src/bin/credentials-provider.rs
+++ b/aws/sdk/examples/sts/src/bin/credentials-provider.rs
@@ -3,6 +3,7 @@
  * SPDX-License-Identifier: Apache-2.0.
  */
 
+use aws_auth::provider::lazy_caching::LazyCachingCredentialsProvider;
 use aws_auth::provider::{async_provide_credentials_fn, CredentialsError};
 use sts::Credentials;
 
@@ -13,10 +14,10 @@ async fn main() -> Result<(), dynamodb::Error> {
     tracing_subscriber::fmt::init();
     let client = sts::Client::from_env();
 
-    // NOTE: Do not use this in production! This will grab new credentials for every request.
-    // A high quality caching credential provider implementation is in the roadmap.
-    let dynamodb_conf = dynamodb::Config::builder()
-        .credentials_provider(async_provide_credentials_fn(move || {
+    // NOTE: Do not use LazyCachingCredentialsProvider in production yet!
+    // It hasn't implemented timeout or panic safety yet.
+    let sts_provider = LazyCachingCredentialsProvider::builder()
+        .refresh(async_provide_credentials_fn(move || {
             let client = client.clone();
             async move {
                 let session_token = client
@@ -40,6 +41,10 @@ async fn main() -> Result<(), dynamodb::Error> {
         }))
         .build();
 
+    let dynamodb_conf = dynamodb::Config::builder()
+        .credentials_provider(sts_provider)
+        .build();
+
     let client = dynamodb::Client::from_conf(dynamodb_conf);
     println!("tables: {:?}", client.list_tables().send().await?);
     Ok(())