Unverified Commit 980b5c44 authored by ysaito1001's avatar ysaito1001 Committed by GitHub
Browse files

Add fallback_on_interrupt to the ProvideCredentials trait (#2246)



* Implement RFC for providing fallback credentials

This commit implements the changes checklist in the RFC for providing
fallback credentials.

* Remove needless lifetime parameter

* Update CHANGELOG.next.toml

---------

Co-authored-by: default avatarYuki Saito <awsaito@amazon.com>
parent e6c3a4b9
Loading
Loading
Loading
Loading
+47 −1
Original line number Diff line number Diff line
@@ -10,3 +10,49 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[aws-sdk-rust]]
message = """
Provide a way to retrieve fallback credentials if a call to `provide_credentials` is interrupted. An interrupt can occur when a timeout future is raced against a future for `provide_credentials`, and the former wins the race. A new method, `fallback_on_interrupt` on the `ProvideCredentials` trait, can be used in that case. The following code snippet from `LazyCredentialsCache::provide_cached_credentials` has been updated like so:

Before:
```rust
let timeout_future = self.sleeper.sleep(self.load_timeout);
// --snip--
let future = Timeout::new(provider.provide_credentials(), timeout_future);
let result = cache
    .get_or_load(|| {
        async move {
            let credentials = future.await.map_err(|_err| {
                CredentialsError::provider_timed_out(load_timeout)
            })??;
            // --snip--
        }
    }).await;
// --snip--
```

After:
```rust
let timeout_future = self.sleeper.sleep(self.load_timeout);
// --snip--
let future = Timeout::new(provider.provide_credentials(), timeout_future);
let result = cache
    .get_or_load(|| {
        async move {
           let credentials = match future.await {
                Ok(creds) => creds?,
                Err(_err) => match provider.fallback_on_interrupt() { // can provide fallback credentials
                    Some(creds) => creds,
                    None => return Err(CredentialsError::provider_timed_out(load_timeout)),
                }
            };
            // --snip--
        }
    }).await;
// --snip--
```
"""
references = ["smithy-rs#2246"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "ysaito1001"
+5 −0
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
use std::borrow::Cow;

use aws_credential_types::provider::{self, future, ProvideCredentials};
use aws_credential_types::Credentials;
use tracing::Instrument;

use crate::environment::credentials::EnvironmentVariableCredentialsProvider;
@@ -83,6 +84,10 @@ impl ProvideCredentials for DefaultCredentialsChain {
    {
        future::ProvideCredentials::new(self.credentials())
    }

    fn fallback_on_interrupt(&self) -> Option<Credentials> {
        self.provider_chain.fallback_on_interrupt()
    }
}

/// Builder for [`DefaultCredentialsChain`](DefaultCredentialsChain)
+115 −1
Original line number Diff line number Diff line
@@ -3,7 +3,10 @@
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::{
    provider::{self, error::CredentialsError, future, ProvideCredentials},
    Credentials,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use std::borrow::Cow;
use tracing::Instrument;
@@ -104,4 +107,115 @@ impl ProvideCredentials for CredentialsProviderChain {
    {
        future::ProvideCredentials::new(self.credentials())
    }

    fn fallback_on_interrupt(&self) -> Option<Credentials> {
        for (_, provider) in &self.providers {
            match provider.fallback_on_interrupt() {
                creds @ Some(_) => return creds,
                None => {}
            }
        }
        None
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        let chain = CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()));

        // Let the first call to `provide_credentials` succeed.
        let expected = chain.provide_credentials().await.unwrap();

        // Let the second call fail with an external timeout.
        let timeout = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(Duration::from_millis(300)),
        );
        match timeout.await {
            Ok(_) => assert!(false, "provide_credentials completed before timeout future"),
            Err(_err) => match chain.fallback_on_interrupt() {
                Some(actual) => assert_eq!(actual, expected),
                None => assert!(
                    false,
                    "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
                ),
            },
        };
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        let chain = CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()));

        // Let the first call to `provide_credentials` succeed.
        let expected = chain.provide_credentials().await.unwrap();

        // Let the second call fail with an external timeout.
        let timeout = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(Duration::from_millis(100)),
        );
        match timeout.await {
            Ok(_) => assert!(false, "provide_credentials completed before timeout future"),
            Err(_err) => match chain.fallback_on_interrupt() {
                Some(actual) => assert_eq!(actual, expected),
                None => assert!(
                    false,
                    "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
                ),
            },
        };
    }
}
+12 −3
Original line number Diff line number Diff line
@@ -78,10 +78,19 @@ impl ProvideCachedCredentials for LazyCredentialsCache {
                let result = cache
                    .get_or_load(|| {
                        let span = info_span!("lazy_load_credentials");
                        let provider = provider.clone();
                        async move {
                            let credentials = future.await.map_err(|_err| {
                                CredentialsError::provider_timed_out(load_timeout)
                            })??;
                            let credentials = match future.await {
                                Ok(creds) => creds?,
                                Err(_err) => match provider.fallback_on_interrupt() {
                                    Some(creds) => creds,
                                    None => {
                                        return Err(CredentialsError::provider_timed_out(
                                            load_timeout,
                                        ))
                                    }
                                },
                            };
                            // If the credentials don't have an expiration time, then create a default one
                            let expiry = credentials
                                .expiry()
+14 −1
Original line number Diff line number Diff line
@@ -249,7 +249,7 @@ pub mod future {

    type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

    /// Future new-type that the `ProvideCredentials` trait must return.
    /// Future new-type that `ProvideCredentials::provide_credentials` must return.
    #[derive(Debug)]
    pub struct ProvideCredentials<'a>(NowOrLater<super::Result, BoxFuture<'a, super::Result>>);

@@ -280,6 +280,19 @@ pub trait ProvideCredentials: Send + Sync + std::fmt::Debug {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
    where
        Self: 'a;

    /// Returns fallback credentials.
    ///
    /// This method should be used as a fallback plan, i.e., when
    /// a call to `provide_credentials` is interrupted and its future
    /// fails to complete.
    ///
    /// The fallback credentials should be set aside and ready to be returned
    /// immediately. Therefore, the user should NOT go fetch new credentials
    /// within this method, which might cause a long-running operation.
    fn fallback_on_interrupt(&self) -> Option<Credentials> {
        None
    }
}

impl ProvideCredentials for Credentials {