Unverified Commit 6e8c1d87 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Improve AssumeRoleProvider (#1296)

This adds to improvements to `AssumeRoleProvider`:
1. This adds a cache aroudn this provider. Since this provider is intended to be used directly by customers,
it should be cached to improve performance and predictability.

2. Add support for configuring the role session length.
parent 775d227c
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -40,3 +40,15 @@ message = "Treat blank environment variable credentials (`AWS_ACCESS_KEY_ID` and
references = ["aws-sdk-rust#1271"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "elrob"

[[aws-sdk-rust]]
message = "Add support for configuring the session length in [AssumeRoleProvider](https://docs.rs/aws-config/latest/aws_config/sts/struct.AssumeRoleProvider.html)"
references = ["aws-sdk-rust#479", "smithy-rs#1296"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"

[[aws-sdk-rust]]
message = "Add caching to [AssumeRoleProvider](https://docs.rs/aws-config/latest/aws_config/sts/struct.AssumeRoleProvider.html)"
references = ["smithy-rs#1296"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "rcoh"
+143 −25
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@
use aws_sdk_sts::error::AssumeRoleErrorKind;
use aws_sdk_sts::middleware::DefaultMiddleware;
use aws_sdk_sts::operation::AssumeRole;
use aws_smithy_async::rt::sleep::default_async_sleep;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::http_connector::HttpSettings;
use aws_smithy_http::result::SdkError;
@@ -16,8 +15,10 @@ use aws_types::credentials::{
    self, future, CredentialsError, ProvideCredentials, SharedCredentialsProvider,
};
use aws_types::region::Region;
use std::time::Duration;

use crate::connector::{default_connector, expect_connector};
use crate::meta::credentials::LazyCachingCredentialsProvider;
use crate::provider_config::ProviderConfig;
use tracing::Instrument;

/// Credentials provider that uses credentials provided by another provider to assume a role
@@ -42,6 +43,11 @@ use tracing::Instrument;
/// ```
#[derive(Debug)]
pub struct AssumeRoleProvider {
    cache: LazyCachingCredentialsProvider,
}

#[derive(Debug)]
struct Inner {
    sts: aws_smithy_client::Client<DynConnector, DefaultMiddleware>,
    conf: aws_sdk_sts::Config,
    op: aws_sdk_sts::input::AssumeRoleInput,
@@ -68,7 +74,8 @@ pub struct AssumeRoleProviderBuilder {
    external_id: Option<String>,
    session_name: Option<String>,
    region: Option<Region>,
    connection: Option<aws_smithy_client::erase::DynConnector>,
    conf: Option<ProviderConfig>,
    session_length: Option<Duration>,
}

impl AssumeRoleProviderBuilder {
@@ -84,8 +91,9 @@ impl AssumeRoleProviderBuilder {
            role_arn: role.into(),
            external_id: None,
            session_name: None,
            session_length: None,
            region: None,
            connection: None,
            conf: None,
        }
    }

@@ -110,6 +118,23 @@ impl AssumeRoleProviderBuilder {
        self
    }

    /// Set the expiration time of the role session.
    ///
    /// When unset, this value defaults to 1 hour.
    ///
    /// The value specified can range from 900 seconds (15 minutes) up to the maximum session duration
    /// set for the role. The maximum session duration setting can have a value from 1 hour to 12 hours.
    /// If you specify a value higher than this setting or the administrator setting (whichever is lower),
    /// **you will be unable to assume the role**. For example, if you specify a session duration of 12 hours,
    /// but your administrator set the maximum session duration to 6 hours, you cannot assume the role.
    ///
    /// For more information, see
    /// [duration_seconds](aws_sdk_sts::input::assume_role_input::Builder::duration_seconds)
    pub fn session_length(mut self, length: Duration) -> Self {
        self.session_length = Some(length);
        self
    }

    /// Set the region to assume the role in.
    ///
    /// This dictates which STS endpoint the AssumeRole action is invoked on.
@@ -118,32 +143,41 @@ impl AssumeRoleProviderBuilder {
        self
    }

    /// Set the backing connection to use when talking to STS.
    ///
    /// If the `rustls` or `nativetls` features are enabled, this field is optional and a default
    /// backing connection will be provided.
    pub fn connection(mut self, conn: impl aws_smithy_client::bounds::SmithyConnector) -> Self {
        self.connection = Some(aws_smithy_client::erase::DynConnector::new(conn));
        let conf = match self.conf {
            Some(conf) => conf.with_http_connector(DynConnector::new(conn)),
            None => ProviderConfig::default().with_http_connector(DynConnector::new(conn)),
        };
        self.conf = Some(conf);
        self
    }

    /// Override the configuration used for this provider
    ///
    /// This enables overriding the connection used to communicate with STS in addition to other internal
    /// fields like the time source and sleep implementation used for caching.
    pub fn configure(mut self, conf: &ProviderConfig) -> Self {
        self.conf = Some(conf.clone());
        self
    }

    /// Build a credentials provider for this role authorized by the given `provider`.
    pub fn build(self, provider: impl Into<SharedCredentialsProvider>) -> AssumeRoleProvider {
        let conf = self.conf.unwrap_or_default();
        let config = aws_sdk_sts::Config::builder()
            .credentials_provider(provider.into())
            .region(self.region.clone())
            .build();

        let conn = self.connection.unwrap_or_else(|| {
            expect_connector(default_connector(
                &HttpSettings::default(),
                default_async_sleep(),
            ))
        });
        let conn = conf
            .connector(&HttpSettings::default())
            .expect("A connector must be provided");
        let client = aws_smithy_client::Builder::new()
            .connector(conn)
            .middleware(DefaultMiddleware::new())
            .sleep_impl(default_async_sleep())
            .sleep_impl(conf.sleep())
            .build();

        let session_name = self
@@ -154,24 +188,24 @@ impl AssumeRoleProviderBuilder {
            .set_role_arn(Some(self.role_arn))
            .set_external_id(self.external_id)
            .set_role_session_name(Some(session_name))
            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32))
            .build()
            .expect("operation is valid");

        AssumeRoleProvider {
        let inner = Inner {
            sts: client,
            conf: config,
            op: operation,
        }
        };
        let cache = LazyCachingCredentialsProvider::builder()
            .configure(&conf)
            .load(inner)
            .build();
        AssumeRoleProvider { cache }
    }
}

impl AssumeRoleProvider {
    #[tracing::instrument(
        name = "assume_role",
        level = "info",
        skip(self),
        fields(op = ?self.op)
    )]
impl Inner {
    async fn credentials(&self) -> credentials::Result {
        tracing::info!("assuming role");

@@ -213,11 +247,95 @@ impl AssumeRoleProvider {
    }
}

impl ProvideCredentials for AssumeRoleProvider {
impl ProvideCredentials for Inner {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials
    where
        Self: 'a,
    {
        future::ProvideCredentials::new(self.credentials())
        future::ProvideCredentials::new(
            self.credentials()
                .instrument(tracing::info_span!("assume_role")),
        )
    }
}

impl ProvideCredentials for AssumeRoleProvider {
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
    where
        Self: 'a,
    {
        self.cache.provide_credentials()
    }
}

#[cfg(test)]
mod test {
    use crate::provider_config::ProviderConfig;
    use crate::sts::AssumeRoleProvider;
    use aws_smithy_client::erase::DynConnector;
    use aws_smithy_client::test_connection::capture_request;
    use aws_smithy_http::body::SdkBody;
    use aws_types::credentials::{ProvideCredentials, SharedCredentialsProvider};
    use aws_types::os_shim_internal::{ManualTimeSource, TimeSource};
    use aws_types::region::Region;
    use aws_types::Credentials;
    use std::time::{Duration, UNIX_EPOCH};

    #[tokio::test]
    async fn configures_session_length() {
        let (server, request) = capture_request(None);
        let provider_conf = ProviderConfig::empty()
            .with_time_source(TimeSource::manual(&ManualTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            )))
            .with_http_connector(DynConnector::new(server));
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&provider_conf)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build(SharedCredentialsProvider::new(Credentials::new(
                "base",
                "basesecret",
                Some("token".to_string()),
                None,
                "inner",
            )));
        let _ = provider.provide_credentials().await;
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
    }

    #[tokio::test]
    async fn provider_caches_credentials() {
        let resp = http::Response::new(SdkBody::from(
            "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/imds-chained-role-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n",
        ));
        let (server, _request) = capture_request(Some(resp));
        let provider_conf = ProviderConfig::empty()
            .with_time_source(TimeSource::manual(&ManualTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            )))
            .with_http_connector(DynConnector::new(server));
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&provider_conf)
            .region(Region::new("us-east-1"))
            .build(SharedCredentialsProvider::new(Credentials::new(
                "base",
                "basesecret",
                Some("token".to_string()),
                None,
                "inner",
            )));
        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("cached credentials should be returned");
        assert_eq!(creds_first, creds_second);
    }
}