Unverified Commit 681d3b33 authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Add static stability support to IMDS credentials provider (#2258)

* Add static stability support to ImdsCredentialsProvider

This commit adds static stability support to `ImdsCredentialsProvider`.
Static stability refers to continued availability of a service in the
face of impaired dependencies. In case IMDS is not available, we still
allow requests to be dispatched with expired credentials. This, in turn,
allows the target service to makes the ultimate decision as to whether
requests sent are valid or not instead of the client SDK determining
their validity.

The way it is implemented is `ImdsCredentialsProvider` now stores a last
retrieved credentials which will later be served when IMDS is unreachable.

* Add tests to IMDS credentials provider

This commit adds tests to IMDS credentials providers for static stability
support. These tests are prescribed in https://github.com/awslabs/smithy-rs/issues/2117

.
From an IMDS credentials provider' perspective, however, some of the tests
are considered to fall under the same equivalence class with others.
Therefore, a single test can cover multiple test cases.

* Update CHANGELOG.next.toml

* Update CHANGELOG.next.toml

Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>

---------

Co-authored-by: default avatarYuki Saito <awsaito@amazon.com>
Co-authored-by: default avatarJohn DiSanti <jdisanti@amazon.com>
parent a6ef0788
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -62,3 +62,11 @@ message = "The [`@uniqueItems`](https://smithy.io/2.0/spec/constraint-traits.htm
references = ["smithy-rs#2232", "smithy-rs#1670"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server"}
author = "david-perez"

[[aws-sdk-rust]]
message = """
Add static stability support to IMDS credentials provider. It does not alter common use cases for the provider, but allows the provider to serve expired credentials in case IMDS is unreachable. This allows requests to be dispatched to a target service with expired credentials. This, in turn, allows the target service to make the ultimate decision as to whether requests sent are valid or not.
"""
references = ["smithy-rs#2258"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "ysaito1001"
+3 −0
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@ ring = "0.16"
hex = "0.4.3"
zeroize = "1"

# implementation detail of IMDS credentials provider
fastrand = "1"

bytes = "1.1.0"
http = "0.2.4"
tower = { version = "0.4.8" }
+26 −6
Original line number Diff line number Diff line
@@ -236,6 +236,12 @@ mod test {
            make_test!($name, execute_from_live_traffic);
        };
        ($name: ident, $func: ident) => {
            make_test!($name, $func, std::convert::identity);
        };
        ($name: ident, $provider_config_builder: expr) => {
            make_test!($name, execute, $provider_config_builder);
        };
        ($name: ident, $func: ident, $provider_config_builder: expr) => {
            #[traced_test]
            #[tokio::test]
            async fn $name() {
@@ -243,7 +249,9 @@ mod test {
                    "./test-data/default-provider-chain/",
                    stringify!($name)
                ))
                .await
                .unwrap()
                .with_provider_config($provider_config_builder)
                .$func(|conf| async {
                    crate::default_provider::credentials::Builder::default()
                        .configure(conf)
@@ -269,12 +277,23 @@ mod test {

    make_test!(imds_no_iam_role);
    make_test!(imds_default_chain_error);
    make_test!(imds_default_chain_success);
    make_test!(imds_default_chain_success, |config| {
        config.with_time_source(aws_credential_types::time_source::TimeSource::testing(
            &aws_credential_types::time_source::TestingTimeSource::new(std::time::UNIX_EPOCH),
        ))
    });
    make_test!(imds_assume_role);
    make_test!(imds_config_with_no_creds);
    make_test!(imds_config_with_no_creds, |config| {
        config.with_time_source(aws_credential_types::time_source::TimeSource::testing(
            &aws_credential_types::time_source::TestingTimeSource::new(std::time::UNIX_EPOCH),
        ))
    });
    make_test!(imds_disabled);
    make_test!(imds_default_chain_retries);

    make_test!(imds_default_chain_retries, |config| {
        config.with_time_source(aws_credential_types::time_source::TimeSource::testing(
            &aws_credential_types::time_source::TestingTimeSource::new(std::time::UNIX_EPOCH),
        ))
    });
    make_test!(ecs_assume_role);
    make_test!(ecs_credentials);
    make_test!(ecs_credentials_invalid_profile);
@@ -284,11 +303,12 @@ mod test {

    #[tokio::test]
    async fn profile_name_override() {
        let (_, conf) =
        let conf =
            TestEnvironment::from_dir("./test-data/default-provider-chain/profile_static_keys")
                .await
                .unwrap()
                .provider_config()
                .await;
                .clone();
        let provider = DefaultCredentialsChain::builder()
            .profile_name("secondary")
            .configure(conf)
+238 −8
Original line number Diff line number Diff line
@@ -14,11 +14,16 @@ use crate::imds::client::LazyClient;
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::time_source::TimeSource;
use aws_credential_types::Credentials;
use aws_types::os_shim_internal::Env;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

const CREDENTIAL_EXPIRATION_INTERVAL: Duration = Duration::from_secs(15 * 60);

#[derive(Debug)]
struct ImdsCommunicationError {
@@ -45,6 +50,8 @@ pub struct ImdsCredentialsProvider {
    client: LazyClient,
    env: Env,
    profile: Option<String>,
    time_source: TimeSource,
    last_retrieved_credentials: Arc<RwLock<Option<Credentials>>>,
}

/// Builder for [`ImdsCredentialsProvider`]
@@ -53,6 +60,7 @@ pub struct Builder {
    provider_config: Option<ProviderConfig>,
    profile_override: Option<String>,
    imds_override: Option<imds::Client>,
    last_retrieved_credentials: Option<Credentials>,
}

impl Builder {
@@ -86,6 +94,13 @@ impl Builder {
        self
    }

    #[allow(dead_code)]
    #[cfg(test)]
    fn last_retrieved_credentials(mut self, credentials: Credentials) -> Self {
        self.last_retrieved_credentials = Some(credentials);
        self
    }

    /// Create an [`ImdsCredentialsProvider`] from this builder.
    pub fn build(self) -> ImdsCredentialsProvider {
        let provider_config = self.provider_config.unwrap_or_default();
@@ -102,6 +117,8 @@ impl Builder {
            client,
            env,
            profile: self.profile_override,
            time_source: provider_config.time_source(),
            last_retrieved_credentials: Arc::new(RwLock::new(self.last_retrieved_credentials)),
        }
    }
}
@@ -117,6 +134,10 @@ impl ProvideCredentials for ImdsCredentialsProvider {
    {
        future::ProvideCredentials::new(self.credentials())
    }

    fn fallback_on_interrupt(&self) -> Option<Credentials> {
        self.last_retrieved_credentials.read().unwrap().clone()
    }
}

impl ImdsCredentialsProvider {
@@ -167,7 +188,36 @@ impl ImdsCredentialsProvider {
        }
    }

    async fn credentials(&self) -> provider::Result {
    // Extend the cached expiration time if necessary
    //
    // This allows continued use of the credentials even when IMDS returns expired ones.
    fn maybe_extend_expiration(&self, expiration: SystemTime) -> SystemTime {
        let rng = fastrand::Rng::with_seed(
            self.time_source
                .now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .expect("now should be after UNIX EPOCH")
                .as_secs(),
        );
        // calculate credentials' refresh offset with jitter
        let refresh_offset =
            CREDENTIAL_EXPIRATION_INTERVAL + Duration::from_secs(rng.u64(120..=600));
        let new_expiry = self.time_source.now() + refresh_offset;

        if new_expiry < expiration {
            return expiration;
        }

        tracing::warn!(
            "Attempting credential expiration extension due to a credential service availability issue. \
            A refresh of these credentials will be attempted again within the next {:.2} minutes.",
            refresh_offset.as_secs_f64() / 60.0,
        );

        new_expiry
    }

    async fn retrieve_credentials(&self) -> provider::Result {
        if self.imds_disabled() {
            tracing::debug!("IMDS disabled because $AWS_EC2_METADATA_DISABLED was set to `true`");
            return Err(CredentialsError::not_loaded(
@@ -196,13 +246,18 @@ impl ImdsCredentialsProvider {
                session_token,
                expiration,
                ..
            })) => Ok(Credentials::new(
            })) => {
                let expiration = self.maybe_extend_expiration(expiration);
                let creds = Credentials::new(
                    access_key_id,
                    secret_access_key,
                    Some(session_token.to_string()),
                    expiration.into(),
                    "IMDSv2",
            )),
                );
                *self.last_retrieved_credentials.write().unwrap() = Some(creds.clone());
                Ok(creds)
            }
            Ok(JsonCredentials::Error { code, message })
                if code == codes::ASSUME_ROLE_UNAUTHORIZED_ACCESS =>
            {
@@ -222,16 +277,34 @@ impl ImdsCredentialsProvider {
            Err(invalid) => Err(CredentialsError::unhandled(invalid)),
        }
    }

    async fn credentials(&self) -> provider::Result {
        match self.retrieve_credentials().await {
            creds @ Ok(_) => creds,
            // Any failure while retrieving credentials MUST NOT impede use of existing credentials.
            err => match &*self.last_retrieved_credentials.read().unwrap() {
                Some(creds) => Ok(creds.clone()),
                _ => err,
            },
        }
    }
}

#[cfg(test)]
mod test {
    use std::time::{Duration, UNIX_EPOCH};

    use crate::imds::client::test::{
        imds_request, imds_response, make_client, token_request, token_response,
    };
    use crate::imds::credentials::ImdsCredentialsProvider;
    use crate::provider_config::ProviderConfig;
    use aws_credential_types::provider::ProvideCredentials;
    use aws_credential_types::time_source::{TestingTimeSource, TimeSource};
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_client::erase::DynConnector;
    use aws_smithy_client::test_connection::TestConnection;
    use tracing_test::traced_test;

    const TOKEN_A: &str = "token_a";

@@ -268,4 +341,161 @@ mod test {
        assert_eq!(creds2.access_key_id(), "ASIARTEST2");
        connection.assert_requests_match(&[]);
    }

    #[tokio::test]
    #[traced_test]
    async fn expired_credentials_should_be_extended() {
        let connection = TestConnection::new(vec![
                (
                    token_request("http://169.254.169.254", 21600),
                    token_response(21600, TOKEN_A),
                ),
                (
                    imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
                    imds_response(r#"profile-name"#),
                ),
                (
                    imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
                    imds_response("{\n  \"Code\" : \"Success\",\n  \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n  \"Type\" : \"AWS-HMAC\",\n  \"AccessKeyId\" : \"ASIARTEST\",\n  \"SecretAccessKey\" : \"testsecret\",\n  \"Token\" : \"testtoken\",\n  \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
                ),
            ]);

        // set to 2021-09-21T17:41:25Z that renders fetched credentials already expired (2021-09-21T04:16:53Z)
        let time_of_request_to_fetch_credentials = UNIX_EPOCH + Duration::from_secs(1632246085);
        let time_source = TimeSource::testing(&TestingTimeSource::new(
            time_of_request_to_fetch_credentials,
        ));

        tokio::time::pause();

        let provider_config = ProviderConfig::no_configuration()
            .with_http_connector(DynConnector::new(connection.clone()))
            .with_time_source(time_source)
            .with_sleep(TokioSleep::new());
        let client = crate::imds::Client::builder()
            .configure(&provider_config)
            .build()
            .await
            .expect("valid client");
        let provider = ImdsCredentialsProvider::builder()
            .configure(&provider_config)
            .imds_client(client)
            .build();
        let creds = provider.provide_credentials().await.expect("valid creds");
        assert!(creds.expiry().unwrap() > time_of_request_to_fetch_credentials);
        connection.assert_requests_match(&[]);

        // We should inform customers that expired credentials are being used for stability.
        assert!(logs_contain("Attempting credential expiration extension"));
    }

    #[tokio::test]
    #[cfg(any(feature = "rustls", feature = "native-tls"))]
    async fn read_timeout_during_credentials_refresh_should_yield_last_retrieved_credentials() {
        let client = crate::imds::Client::builder()
            // 240.* can never be resolved
            .endpoint(http::Uri::from_static("http://240.0.0.0"))
            .build()
            .await
            .expect("valid client");
        let expected = aws_credential_types::Credentials::for_tests();
        let provider = ImdsCredentialsProvider::builder()
            .imds_client(client)
            // seed fallback credentials for testing
            .last_retrieved_credentials(expected.clone())
            .build();
        let actual = provider.provide_credentials().await;
        assert_eq!(actual.unwrap(), expected);
    }

    #[tokio::test]
    #[cfg(any(feature = "rustls", feature = "native-tls"))]
    async fn read_timeout_during_credentials_refresh_should_error_without_last_retrieved_credentials(
    ) {
        let client = crate::imds::Client::builder()
            // 240.* can never be resolved
            .endpoint(http::Uri::from_static("http://240.0.0.0"))
            .build()
            .await
            .expect("valid client");
        let provider = ImdsCredentialsProvider::builder()
            .imds_client(client)
            // no fallback credentials provided
            .build();
        let actual = provider.provide_credentials().await;
        assert!(matches!(
            actual,
            Err(aws_credential_types::provider::error::CredentialsError::CredentialsNotLoaded(_))
        ));
    }

    #[tokio::test]
    #[cfg(any(feature = "rustls", feature = "native-tls"))]
    async fn external_timeout_during_credentials_refresh_should_yield_last_retrieved_credentials() {
        use aws_sdk_sso::config::AsyncSleep;
        let client = crate::imds::Client::builder()
            // 240.* can never be resolved
            .endpoint(http::Uri::from_static("http://240.0.0.0"))
            .build()
            .await
            .expect("valid client");
        let expected = aws_credential_types::Credentials::for_tests();
        let provider = ImdsCredentialsProvider::builder()
            .imds_client(client)
            // seed fallback credentials for testing
            .last_retrieved_credentials(expected.clone())
            .build();
        let sleeper = aws_smithy_async::rt::sleep::TokioSleep::new();
        let timeout = aws_smithy_async::future::timeout::Timeout::new(
            provider.provide_credentials(),
            // make sure `sleeper.sleep` will be timed out first by setting a shorter duration than connect timeout
            sleeper.sleep(std::time::Duration::from_millis(100)),
        );
        match timeout.await {
            Ok(_) => assert!(false, "provide_credentials completed before timeout future"),
            Err(_err) => match provider.fallback_on_interrupt() {
                Some(actual) => assert_eq!(actual, expected),
                None => assert!(
                    false,
                    "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
                ),
            },
        };
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_used_when_imds_returns_500_during_credentials_refresh()
    {
        let connection = TestConnection::new(vec![
                // The next three request/response pairs will correspond to the first call to `provide_credentials`.
                // During the call, it populates last_retrieved_credentials.
                (
                    token_request("http://169.254.169.254", 21600),
                    token_response(21600, TOKEN_A),
                ),
                (
                    imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
                    imds_response(r#"profile-name"#),
                ),
                (
                    imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
                    imds_response("{\n  \"Code\" : \"Success\",\n  \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n  \"Type\" : \"AWS-HMAC\",\n  \"AccessKeyId\" : \"ASIARTEST\",\n  \"SecretAccessKey\" : \"testsecret\",\n  \"Token\" : \"testtoken\",\n  \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
                ),
                // The following request/response pair corresponds to the second call to `provide_credentials`.
                // During the call, IMDS returns response code 500.
                (
                    imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
                    http::Response::builder().status(500).body("").unwrap(),
                ),
            ]);
        let provider = ImdsCredentialsProvider::builder()
            .imds_client(make_client(&connection).await)
            .build();
        let creds1 = provider.provide_credentials().await.expect("valid creds");
        assert_eq!(creds1.access_key_id(), "ASIARTEST");
        // `creds1` should be returned as fallback credentials and assigned to `creds2`
        let creds2 = provider.provide_credentials().await.expect("valid creds");
        assert_eq!(creds1, creds2);
        connection.assert_requests_match(&[]);
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -470,6 +470,7 @@ mod test {
                    "./test-data/profile-provider/",
                    stringify!($name)
                ))
                .await
                .unwrap()
                .execute(|conf| async move { Builder::default().configure(&conf).build() })
                .await
Loading