Unverified Commit 5129c1f5 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Remove ProviderConfig as the configuration source for AssumeRoleProvider (#3014)

## Motivation and Context
`AssumeRoleProvider` currently uses `ProviderConfig` as a source of
configuration, but that API is hard use and not intended for external
consumption.

This fixes the Assume Role issue but only for `AssumeRoleProvider`

## Description
Update the API (see changelog) to be more ergonomic and derive
configuration from `SdkConfig` instead.

## Testing
Existing tests + new unit tests

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent f2041a42
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -253,3 +253,11 @@ message = "The IMDS Client builder's `build()` method is no longer async."
references = ["smithy-rs#2997"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
message = """The API for [`AssumeRoleProvider`](https://docs.rs/aws-config/latest/aws_config/sts/struct.AssumeRoleProvider.html) has been updated to derive configuration from [`SdkConfig`](https://docs.rs/aws-config/latest/aws_config/struct.SdkConfig.html) instead of `ProviderConfig`.

For more information, see the [Change Log Discussion](https://github.com/awslabs/aws-sdk-rust/discussions/906)"""
meta = { "breaking" = true, "tada" = false, "bug" = false }
references = ["smithy-rs#3014"]
author = "rcoh"
+7 −2
Original line number Diff line number Diff line
@@ -808,13 +808,18 @@ mod loader {
                movable.fetch_add(1, Ordering::Relaxed);
                http::Response::new("ok!")
            });
            let config = from_env().http_connector(conn.clone()).load().await;
            let config = from_env()
                .fs(Fs::from_slice(&[]))
                .env(Env::from_slice(&[]))
                .http_connector(conn.clone())
                .load()
                .await;
            config
                .credentials_provider()
                .unwrap()
                .provide_credentials()
                .await
                .expect_err("no traffic is allowed");
                .expect_err("did not expect credentials to be loaded—no traffic is allowed");
            let num_requests = num_requests.load(Ordering::Relaxed);
            assert!(num_requests > 0, "{}", num_requests);
        }
+205 −67
Original line number Diff line number Diff line
@@ -5,18 +5,18 @@

//! Assume credentials for a role through the AWS Security Token Service (STS).

use crate::connector::expect_connector;
use crate::provider_config::ProviderConfig;
use aws_credential_types::cache::CredentialsCache;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::provider::{
    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
};
use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
use aws_sdk_sts::operation::assume_role::AssumeRoleError;
use aws_sdk_sts::types::PolicyDescriptorType;
use aws_sdk_sts::Client as StsClient;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_http::result::SdkError;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::region::Region;
use aws_types::SdkConfig;
use std::time::Duration;
use tracing::Instrument;

@@ -28,19 +28,44 @@ use tracing::Instrument;
/// the desired role.
///
/// # Examples
/// Create an AssumeRoleProvider explicitly set to us-east-2 that utilizes the default credentials chain.
/// ```no_run
/// use aws_credential_types::Credentials;
/// use aws_config::sts::{AssumeRoleProvider};
/// use aws_config::sts::AssumeRoleProvider;
/// use aws_types::region::Region;
/// use aws_config::environment;
/// use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
/// use std::sync::Arc;
///
/// # async fn docs() {
/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
///   .region(Region::from_static("us-east-2"))
///   .session_name("testAR")
///   .build(Arc::new(EnvironmentVariableCredentialsProvider::new()) as Arc<_>);
///   .build().await;
/// }
/// ```
///
/// Create an AssumeRoleProvider from an explicitly configured base configuration.
/// ```no_run
/// use aws_config::sts::AssumeRoleProvider;
/// use aws_types::region::Region;
/// # async fn docs() {
/// let conf = aws_config::from_env().use_fips(true).load().await;
/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
///   .configure(&conf)
///   .session_name("testAR")
///   .build().await;
/// }
/// ```
///
/// Create an AssumeroleProvider that sources credentials from a provider credential provider:
/// ```no_run
/// use aws_config::sts::AssumeRoleProvider;
/// use aws_types::region::Region;
/// use aws_config::environment::EnvironmentVariableCredentialsProvider;
/// # async fn docs() {
/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
///   .session_name("test-assume-role-session")
///   // only consider environment variables, explicitly.
///   .build_from_provider(EnvironmentVariableCredentialsProvider::new()).await;
/// }
/// ```
///
#[derive(Debug)]
pub struct AssumeRoleProvider {
    inner: Inner,
@@ -72,12 +97,14 @@ pub struct AssumeRoleProviderBuilder {
    role_arn: String,
    external_id: Option<String>,
    session_name: Option<String>,
    region: Option<Region>,
    conf: Option<ProviderConfig>,
    session_length: Option<Duration>,
    policy: Option<String>,
    policy_arns: Option<Vec<PolicyDescriptorType>>,

    region_override: Option<Region>,

    credentials_cache: Option<CredentialsCache>,
    sdk_config: Option<SdkConfig>,
}

impl AssumeRoleProviderBuilder {
@@ -91,14 +118,14 @@ impl AssumeRoleProviderBuilder {
    pub fn new(role: impl Into<String>) -> Self {
        Self {
            role_arn: role.into(),
            credentials_cache: None,
            external_id: None,
            session_name: None,
            session_length: None,
            region: None,
            conf: None,
            policy: None,
            policy_arns: None,
            credentials_cache: None,
            sdk_config: None,
            region_override: None,
        }
    }

@@ -162,20 +189,10 @@ impl AssumeRoleProviderBuilder {

    /// Set the region to assume the role in.
    ///
    /// This dictates which STS endpoint the AssumeRole action is invoked on.
    /// This dictates which STS endpoint the AssumeRole action is invoked on. This will override
    /// a region set from `.configure(...)`
    pub fn region(mut self, region: Region) -> Self {
        self.region = Some(region);
        self
    }

    /// If the `rustls` or `nativetls` features are enabled, this field is optional and a default
    /// backing connection will be provided.
    pub fn connection(mut self, conn: impl aws_smithy_client::bounds::SmithyConnector) -> Self {
        let conf = match self.conf {
            Some(conf) => conf.with_http_connector(DynConnector::new(conn)),
            None => ProviderConfig::default().with_http_connector(DynConnector::new(conn)),
        };
        self.conf = Some(conf);
        self.region_override = Some(region);
        self
    }

@@ -191,36 +208,57 @@ impl AssumeRoleProviderBuilder {
        self
    }

    /// Override the configuration used for this provider
    /// Sets the configuration used for this provider
    ///
    /// This enables overriding the connection used to communicate with STS in addition to other internal
    /// fields like the time source and sleep implementation used for caching.
    pub fn configure(mut self, conf: &ProviderConfig) -> Self {
        self.conf = Some(conf.clone());
    ///
    /// If this field is not provided, configuration from [`aws_config::load_from_env().await`] is used.
    ///
    /// # Examples
    /// ```rust
    /// # async fn docs() {
    /// use aws_types::region::Region;
    /// use aws_config::sts::AssumeRoleProvider;
    /// let config = aws_config::from_env().region(Region::from_static("us-west-2")).load().await;
    /// let assume_role_provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/example")
    ///   .configure(&config)
    ///   .build();
    /// }
    pub fn configure(mut self, conf: &SdkConfig) -> Self {
        self.sdk_config = Some(conf.clone());
        self
    }

    /// Build a credentials provider for this role authorized by the given `provider`.
    pub fn build(self, provider: impl ProvideCredentials + 'static) -> AssumeRoleProvider {
        let conf = self.conf.unwrap_or_default();

        let credentials_cache = self
            .credentials_cache
            .unwrap_or_else(CredentialsCache::no_caching);

        let mut config = aws_sdk_sts::Config::builder()
            .credentials_cache(credentials_cache)
            .credentials_provider(provider)
            .time_source(conf.time_source())
            .region(self.region.clone())
            .http_connector(expect_connector(
                "The AssumeRole credentials provider",
                conf.connector(&Default::default()),
            ));
        config.set_sleep_impl(conf.sleep());
    /// Build a credentials provider for this role.
    ///
    /// Base credentials will be used from the [`SdkConfig`] set via [`Self::configure`] or loaded
    /// from [`aws_config::from_env`](crate::from_env) if `configure` was never called.
    pub async fn build(self) -> AssumeRoleProvider {
        let mut conf = match self.sdk_config {
            Some(conf) => conf,
            None => crate::load_from_env().await,
        };
        // ignore a credentials cache set from SdkConfig
        conf = conf
            .into_builder()
            .credentials_cache(
                self.credentials_cache
                    .unwrap_or(CredentialsCache::no_caching()),
            )
            .build();

        // set a region override if one exists
        if let Some(region) = self.region_override {
            conf = conf.into_builder().region(region).build()
        }

        let config = aws_sdk_sts::config::Builder::from(&conf);

        let time_source = conf.time_source().expect("A time source must be provided.");

        let session_name = self.session_name.unwrap_or_else(|| {
            super::util::default_session_name("assume-role-provider", conf.time_source().now())
            super::util::default_session_name("assume-role-provider", time_source.now())
        });

        let sts_client = StsClient::from_conf(config.build());
@@ -237,6 +275,23 @@ impl AssumeRoleProviderBuilder {
            inner: Inner { fluent_builder },
        }
    }

    /// Build a credentials provider for this role authorized by the given `provider`.
    pub async fn build_from_provider(
        mut self,
        provider: impl ProvideCredentials + 'static,
    ) -> AssumeRoleProvider {
        let conf = match self.sdk_config {
            Some(conf) => conf,
            None => crate::load_from_env().await,
        };
        let conf = conf
            .into_builder()
            .credentials_provider(SharedCredentialsProvider::new(provider))
            .build();
        self.sdk_config = Some(conf);
        self.build().await
    }
}

impl Inner {
@@ -287,40 +342,121 @@ impl ProvideCredentials for AssumeRoleProvider {

#[cfg(test)]
mod test {
    use crate::provider_config::ProviderConfig;
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::ProvideCredentials;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_client::erase::DynConnector;
    use aws_smithy_client::test_connection::{capture_request, TestConnection};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    #[tokio::test]
    async fn configures_session_length() {
        let (server, request) = capture_request(None);
        let provider_conf = ProviderConfig::empty()
            .with_sleep(TokioSleep::new())
            .with_time_source(StaticTimeSource::new(
        let provider_conf = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .with_http_connector(DynConnector::new(server));
            .http_connector(DynConnector::new(server))
            .region(Region::from_static("this-will-be-overridden"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&provider_conf)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build(provide_credentials_fn(|| async {
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }));
        let _ = provider.provide_credentials().await;
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com");
    }

    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (server, request) = capture_request(None);
        let provider_conf = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_connector(DynConnector::new(server))
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                    #[allow(unreachable_code)]
                    Ok(Credentials::for_tests())
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&provider_conf)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com");
    }

    /// Test that `build()` where no provider is passed still works
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (server, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::from_env()
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_connector(server)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = provider.provide_credentials().await;
        let req = request.expect_request();
        let auth_header = req
            .headers()
            .get(AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        // ensure that FIPS & DualStack are also respected
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    #[tokio::test]
@@ -340,10 +476,11 @@ mod test {
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
        );

        let provider_conf = ProviderConfig::empty()
            .with_sleep(sleep)
            .with_time_source(testing_time_source.clone())
            .with_http_connector(DynConnector::new(conn));
        let provider_conf = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_connector(DynConnector::new(conn))
            .build();
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
@@ -364,13 +501,14 @@ mod test {
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&provider_conf)
            .region(Region::new("us-east-1"))
            .build(provide_credentials_fn(move || {
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }));
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
+23 −0
Original line number Diff line number Diff line
@@ -602,4 +602,27 @@ impl SdkConfig {
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Convert this [`SdkConfig`] into a [`Builder`] by cloning it first
    pub fn to_builder(&self) -> Builder {
        self.clone().into_builder()
    }

    /// Convert this [`SdkConfig`] back to a builder to enable modification
    pub fn into_builder(self) -> Builder {
        Builder {
            app_name: self.app_name,
            credentials_cache: self.credentials_cache,
            credentials_provider: self.credentials_provider,
            region: self.region,
            endpoint_url: self.endpoint_url,
            retry_config: self.retry_config,
            sleep_impl: self.sleep_impl,
            time_source: self.time_source,
            timeout_config: self.timeout_config,
            http_connector: self.http_connector,
            use_fips: self.use_fips,
            use_dual_stack: self.use_dual_stack,
        }
    }
}
+6 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@
//! Time source abstraction to support WASM and testing
use std::fmt::Debug;
use std::sync::Arc;
use std::time::SystemTime;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Trait with a `now()` function returning the current time
pub trait TimeSource: Debug + Send + Sync {
@@ -51,6 +51,11 @@ impl StaticTimeSource {
    pub fn new(time: SystemTime) -> Self {
        Self { time }
    }

    /// Creates a new static time source from the provided number of seconds since the UNIX epoch
    pub fn from_secs(epoch_secs: u64) -> Self {
        Self::new(UNIX_EPOCH + Duration::from_secs(epoch_secs))
    }
}

impl TimeSource for StaticTimeSource {