Unverified Commit 273a2698 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

IMDS Client (#701)



* feat(aws-config): Add IMDS Client

* Fix docs

* Add two missing error implementations + usage example

* Fix clippy errors

* Fix doctest in os_shim_internal

* Fix a missed occurence of test connection assertions

* fix doc link

* remove link to private item

* Update changelog, make depedencies optional

* CR feedback cleanups

* Apply suggestions from code review

Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>

* Update aws/rust-runtime/aws-config/src/imds/client.rs

* refactor cache into its own module

* fix clippy warning

* add additional internal changelog entries

* endpoint mode comparsion should be case insensitive

* update docs to remove references to credentials

* Fix bug in token that doubled the buffer period

Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
parent aa883141
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
vNext (Month, Day, Year)
vNext (Month Day, Year)
=======================

- Update `Receiver`s to be `Send` (aws-sdk-rust#224)
**New This Week**
- Add IMDS client to `aws-config` (#701)
- Add `TimeSource` to `aws_types::os_shim_internal` (#701)
- User agent construction is now `const fn` (#701)
- Update event stream `Receiver`s to be `Send` (#702, #aws-sdk-rust#224)

v0.23 (September 14th, 2021)
=======================
+3 −2
Original line number Diff line number Diff line
vNext (Month Day, Year)
=======================

- Update `Receiver`s to be `Send` (aws-sdk-rust#224)
**New This Week**
- Add IMDS client to `aws-config`
- Update event stream `Receiver`s to be `Send` (aws-sdk-rust#224)

v0.0.18-alpha (September 14th, 2021)
=======================
+14 −3
Original line number Diff line number Diff line
@@ -9,12 +9,13 @@ exclude = ["test-data/*"]
[features]
default-provider = ["profile", "imds", "meta", "sts", "environment"]
profile = ["sts", "web-identity-token", "meta", "environment"]
# note: IMDS currently unsupported
imds = []
meta = ["tokio/sync"]
imds = ["profile", "smithy-http", "smithy-types", "smithy-http-tower", "tower", "aws-http", "meta"]
environment = ["meta"]
sts = ["aws-sdk-sts", "aws-hyper"]
web-identity-token = ["sts"]
web-identity-token = ["sts", "profile"]

# SSO is not supported
sso = []

rustls = ["smithy-client/rustls"]
@@ -34,6 +35,16 @@ aws-sdk-sts = { path = "../../sdk/build/aws-sdk/sts", optional = true }
# TODO: remove when middleware stacks are moved inside of clients directly
aws-hyper = { path = "../../sdk/build/aws-sdk/aws-hyper", optional = true }

# imds
smithy-http = { path = "../../sdk/build/aws-sdk/smithy-http", optional = true }
smithy-types = { path = "../../sdk/build/aws-sdk/smithy-types", optional = true }
smithy-http-tower = { path = "../../sdk/build/aws-sdk/smithy-http-tower", optional = true }
tower = { version = "0.4.8", optional = true }
aws-http = { path = "../../sdk/build/aws-sdk/aws-http", optional = true }
bytes = "1.1.0"
http = "0.2.4"


[dev-dependencies]
futures-util = "0.3.16"
tracing-test = "0.1.0"
+15 −0
Original line number Diff line number Diff line
use aws_config::imds::Client;
use std::error::Error;

/// IMDSv2 client usage example
///
/// The IMDS client is used with `aws-config` to load credentials and regions, however, you can also
/// use the client directly. This example demonstrates loading the instance-id from IMDS. More
/// fetures of IMDS can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let imds = Client::builder().build().await?;
    let instance_id = imds.get("/latest/meta-data/instance-id").await?;
    println!("current instance id: {}", instance_id);
    Ok(())
}
+51 −30
Original line number Diff line number Diff line
@@ -3,31 +3,54 @@
 * SPDX-License-Identifier: Apache-2.0.
 */

use aws_types::credentials::CredentialsError;
use aws_types::{credentials, Credentials};
//! Expiry-aware cache
//!
//! [`ExpiringCache`] implements two important features:
//! 1. Respect expiry of contents
//! 2. Deduplicate load requests to prevent thundering herds when no value is present.

use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{OnceCell, RwLock};

#[derive(Clone, Debug)]
pub(super) struct Cache {
    /// Amount of time before the actual credential expiration time
    /// where credentials are considered expired.
#[derive(Debug)]
pub(crate) struct ExpiringCache<T, E> {
    /// Amount of time before the actual expiration time
    /// when the value is considered expired.
    buffer_time: Duration,
    value: Arc<RwLock<OnceCell<(Credentials, SystemTime)>>>,
    value: Arc<RwLock<OnceCell<(T, SystemTime)>>>,
    _phantom: PhantomData<E>,
}

impl<T, E> Clone for ExpiringCache<T, E> {
    fn clone(&self) -> Self {
        Self {
            buffer_time: self.buffer_time,
            value: self.value.clone(),
            _phantom: Default::default(),
        }
    }
}

impl Cache {
    pub fn new(buffer_time: Duration) -> Cache {
        Cache {
impl<T, E> ExpiringCache<T, E>
where
    T: Clone,
{
    pub fn new(buffer_time: Duration) -> Self {
        ExpiringCache {
            buffer_time,
            value: Arc::new(RwLock::new(OnceCell::new())),
            _phantom: Default::default(),
        }
    }

    #[cfg(test)]
    async fn get(&self) -> Option<Credentials> {
    async fn get(&self) -> Option<T>
    where
        T: Clone,
    {
        self.value
            .read()
            .await
@@ -36,39 +59,37 @@ impl Cache {
            .map(|(creds, _expiry)| creds)
    }

    /// Attempts to refresh the cached credentials with the given async future.
    /// Attempts to refresh the cached value with the given future.
    /// If multiple threads attempt to refresh at the same time, one of them will win,
    /// and the others will await that thread's result rather than multiple refreshes occurring.
    /// The function given to acquire a credentials future, `f`, will not be called
    /// if another thread is chosen to load the credentials.
    pub async fn get_or_load<F, Fut>(&self, f: F) -> credentials::Result
    /// The function given to acquire a value future, `f`, will not be called
    /// if another thread is chosen to load the value.
    pub async fn get_or_load<F, Fut>(&self, f: F) -> Result<T, E>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<(Credentials, SystemTime), CredentialsError>>,
        Fut: Future<Output = Result<(T, SystemTime), E>>,
    {
        let lock = self.value.read().await;
        let future = lock.get_or_try_init(f);
        future
            .await
            .map(|(credentials, _expiry)| credentials.clone())
        future.await.map(|(value, _expiry)| value.clone())
    }

    /// If the credentials are expired, clears the cache. Otherwise, yields the current credentials value.
    pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<Credentials> {
        // Short-circuit if the credential is not expired
        if let Some((credentials, expiry)) = self.value.read().await.get() {
    /// If the value is expired, clears the cache. Otherwise, yields the current value.
    pub async fn yield_or_clear_if_expired(&self, now: SystemTime) -> Option<T> {
        // Short-circuit if the value is not expired
        if let Some((value, expiry)) = self.value.read().await.get() {
            if !expired(*expiry, self.buffer_time, now) {
                return Some(credentials.clone());
                return Some(value.clone());
            }
        }

        // Acquire a write lock to clear the cache, but then once the lock is acquired,
        // check again that the credential is not already cleared. If it has been cleared,
        // check again that the value is not already cleared. If it has been cleared,
        // then another thread is refreshing the cache by the time the write lock was acquired.
        let mut lock = self.value.write().await;
        if let Some((_credentials, expiration)) = lock.get() {
            // Also check that we're clearing the expired credentials and not credentials
            // that have been refreshed by another thread.
        if let Some((_value, expiration)) = lock.get() {
            // Also check that we're clearing the expired value and not a value
            // that has been refreshed by another thread.
            if expired(*expiration, self.buffer_time, now) {
                *lock = OnceCell::new();
            }
@@ -83,7 +104,7 @@ fn expired(expiration: SystemTime, buffer_time: Duration, now: SystemTime) -> bo

#[cfg(test)]
mod tests {
    use super::{expired, Cache};
    use super::{expired, ExpiringCache};
    use aws_types::credentials::CredentialsError;
    use aws_types::Credentials;
    use std::time::{Duration, SystemTime};
@@ -110,7 +131,7 @@ mod tests {
    #[traced_test]
    #[tokio::test]
    async fn cache_clears_if_expired_only() {
        let cache = Cache::new(Duration::from_secs(10));
        let cache = ExpiringCache::new(Duration::from_secs(10));
        assert!(cache
            .yield_or_clear_if_expired(epoch_secs(100))
            .await
Loading