Unverified Commit 38f84dc5 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for web identity token providers (#654)

* Add support for web identity token providers

This diff adds support for WebIdentity token providers which may be utilized in three ways:
1. Via environment variables
2. Via a profile directly
3. As a source profile of another assume role profile

Along the way, I overhauled the way that tests work to be primarily self-contained test directories.

* make fs::read_to_end async

For future compatibility I updated read_to_end to be an async function. This drove a subsequent change to push profile file parsing into credentials loading. This should be generally useful anyway since we'll want a way to force reparsing of the profile file.

* Revert change to env

This removes the need for an allocation in the non-test case.

* Fix broken doc tests

* Fix clippy

* Update changelog, fix docs

* Remove caching from profile provider
parent 01cdd25d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ vNext (Month Day, Year)
- Remove Bintray/JCenter source from gradle build. (#651)
- Add support for the smithy auth trait. This enables authorizations that explicitly disable authorization to work when no credentials have been provided. (#652)
- :bug: Fix STS Assume Role with WebIdentity & Assume role with SAML to support clients with no credentials provided (#652)
- Add support for WebIdentityToken providers via profile & environment variables. (#654)
- (internal): Update sigv4 tests to work around behavior change in httparse 1.5. (#656)

v0.20 (August 10th, 2021)
+6 −4
Original line number Diff line number Diff line
@@ -45,7 +45,6 @@ impl ChainProvider {
    }

    async fn credentials(&self) -> CredentialsResult {
        let mut last_error = CredentialsError::Unhandled("no providers".into());
        for (name, provider) in &self.providers {
            let span = tracing::info_span!("load_credentials", provider = %name);
            match provider.provide_credentials().instrument(span).await {
@@ -53,13 +52,16 @@ impl ChainProvider {
                    tracing::info!(provider = %name, "loaded credentials");
                    return Ok(credentials);
                }
                Err(CredentialsError::CredentialsNotLoaded) => {
                    tracing::info!(provider = %name, "provider in chain did not provide credentials");
                }
                Err(e) => {
                    tracing::info!(provider = %name, error = %e, "provider in chain did not provide credentials");
                    last_error = e
                    tracing::warn!(provider = %name, error = %e, "provider failed to provide credentials");
                    return Err(e);
                }
            }
        }
        return Err(last_error);
        return Err(CredentialsError::CredentialsNotLoaded);
    }
}

+65 −48
Original line number Diff line number Diff line
@@ -58,6 +58,7 @@ impl AsyncProvideCredentials for DefaultProviderChain {
#[derive(Default)]
pub struct Builder {
    profile_file_builder: crate::profile::Builder,
    web_identity_builder: crate::web_identity_token::Builder,
    credential_cache: aws_auth::provider::lazy_caching::builder::Builder,
    env: Option<Env>,
}
@@ -68,6 +69,7 @@ impl Builder {
    /// When unset, the default region resolver chain will be used.
    pub fn region(mut self, region: &dyn ProvideRegion) -> Self {
        self.profile_file_builder.set_region(region.region());
        self.web_identity_builder.set_region(region.region());
        self
    }

@@ -76,7 +78,9 @@ impl Builder {
    /// If a connector other than Hyper is used or if the Tokio/Hyper features have been disabled
    /// this method MUST be used to specify a custom connector.
    pub fn connector(mut self, connector: DynConnector) -> Self {
        self.profile_file_builder.set_connector(Some(connector));
        self.profile_file_builder
            .set_connector(Some(connector.clone()));
        self.web_identity_builder.set_connector(Some(connector));
        self
    }

@@ -133,7 +137,8 @@ impl Builder {
    ///
    /// This method exists primarily for testing credential providers
    pub fn fs(mut self, fs: Fs) -> Self {
        self.profile_file_builder.set_fs(Some(fs));
        self.profile_file_builder.set_fs(fs.clone());
        self.web_identity_builder.set_fs(fs);
        self
    }

@@ -143,7 +148,8 @@ impl Builder {
    /// This method exists primarily for testing credential providers
    pub fn env(mut self, env: Env) -> Self {
        self.env = Some(env.clone());
        self.profile_file_builder.set_env(Some(env));
        self.profile_file_builder.set_env(env.clone());
        self.web_identity_builder.set_env(env);
        self
    }

@@ -151,8 +157,10 @@ impl Builder {
        let profile_provider = self.profile_file_builder.build();
        let env_provider =
            EnvironmentVariableCredentialsProvider::new_with_env(self.env.unwrap_or_default());
        let web_identity_token_provider = self.web_identity_builder.build();
        let provider_chain = crate::chain::ChainProvider::first_try("Environment", env_provider)
            .or_else("Profile", profile_provider);
            .or_else("Profile", profile_provider)
            .or_else("WebIdentityToken", web_identity_token_provider);
        let cached_provider = self.credential_cache.load(provider_chain);
        DefaultProviderChain(cached_provider.build())
    }
@@ -160,53 +168,62 @@ impl Builder {

#[cfg(test)]
mod test {
    use crate::DefaultProviderChain;
    use aws_auth::provider::AsyncProvideCredentials;
    use aws_hyper::DynConnector;
    use aws_types::os_shim_internal::{Env, Fs};
    use smithy_client::dvr::ReplayingConnection;
    use tracing_test::traced_test;

    macro_rules! make_test {
        ($name: ident) => {
            #[traced_test]
            #[tokio::test]
    async fn prefer_environment() {
        let env = Env::from_slice(&[
            ("AWS_ACCESS_KEY_ID", "correct_key"),
            ("AWS_SECRET_ACCESS_KEY", "correct_secret"),
            ("HOME", "/Users/me"),
        ]);

        let fs = Fs::from_test_dir("test-data/aws-config/e2e-assume-role", "/Users/me");
        // empty connection will error if it is used
        let connection = ReplayingConnection::new(vec![]);
        let provider = DefaultProviderChain::builder()
            .fs(fs)
            async fn $name() {
                crate::test_case::TestEnvironment::from_dir(concat!(
                    "./test-data/default-provider-chain/",
                    stringify!($name)
                ))
                .unwrap()
                .execute(|fs, env, conn| {
                    crate::default_provider_chain::Builder::default()
                        .env(env)
            .connector(DynConnector::new(connection))
            .build();
        // empty connection will error if it is used
        let creds = provider.provide_credentials().await.expect("valid creds");
        assert_eq!(creds.access_key_id(), "correct_key");
        assert_eq!(creds.secret_access_key(), "correct_secret")
                        .fs(fs)
                        .region(&Region::from_static("us-east-1"))
                        .connector(conn)
                        .build()
                })
                .await
            }
        };
    }

    #[traced_test]
    use aws_sdk_sts::Region;

    use tracing_test::traced_test;

    make_test!(prefer_environment);
    make_test!(profile_static_keys);
    make_test!(web_identity_token_env);
    make_test!(web_identity_source_profile_no_env);
    make_test!(web_identity_token_invalid_jwt);
    make_test!(web_identity_token_source_profile);
    make_test!(web_identity_token_profile);
    make_test!(profile_overrides_web_identity);

    /// Helper that uses `execute_and_update` instead of execute
    ///
    /// If you run this, it will add another HTTP traffic log which re-records the request
    /// data
    #[tokio::test]
    async fn fallback_to_profile() {
        let env = Env::from_slice(&[
            // access keys not in environment
            ("HOME", "/Users/me"),
        ]);

        let fs = Fs::from_test_dir("./test-data/static-keys/aws-config", "/Users/me/.aws");
        // empty connection will error if it is used
        let connection = ReplayingConnection::new(vec![]);
        let provider = DefaultProviderChain::builder()
            .fs(fs)
    #[ignore]
    async fn update_test() {
        crate::test_case::TestEnvironment::from_dir(concat!(
            "./test-data/default-provider-chain/web_identity_token_source_profile",
        ))
        .unwrap()
        .execute_and_update(|fs, env, conn| {
            crate::default_provider_chain::Builder::default()
                .env(env)
            .connector(DynConnector::new(connection))
            .build();
        let creds = provider.provide_credentials().await.expect("valid creds");
        assert_eq!(creds.access_key_id(), "correct_key");
        assert_eq!(creds.secret_access_key(), "correct_secret")
                .fs(fs)
                .region(&Region::from_static("us-east-1"))
                .connector(conn)
                .build()
        })
        .await
    }
}
+14 −1
Original line number Diff line number Diff line
@@ -12,6 +12,9 @@ pub mod profile;

/// Credentials Provider that evaluates a series of providers
pub mod chain;
mod sts_util;
mod test_case;
pub mod web_identity_token;

// create a default connector given the currently enabled cargo features.
// rustls  | native tls | result
@@ -21,6 +24,10 @@ pub mod chain;
// no      | yes        | native_tls
// no      | no         | no default

fn must_have_connector() -> DynConnector {
    default_connector().expect("A connector was not available. Either set a custom connector or enable the `rustls` and `native-tls` crate features.")
}

#[cfg(feature = "rustls")]
fn default_connector() -> Option<DynConnector> {
    Some(DynConnector::new(smithy_client::conns::https()))
@@ -39,6 +46,12 @@ fn default_connector() -> Option<DynConnector> {
// because this doesn't provide any configuration, a runtime and connector must be provided.
#[cfg(all(any(feature = "native-tls", feature = "rustls"), feature = "rt-tokio"))]
/// Default AWS provider chain
///
/// This provider chain will use defaults for all settings. The region will be resolved with the default
/// provider chain. To construct a custom provider, use [`default_provider_chain::Builder`](default_provider_chain::Builder).
pub fn default_provider() -> impl AsyncProvideCredentials {
    default_provider_chain::Builder::default().build()
    let resolved_region = aws_types::region::default_provider();
    default_provider_chain::Builder::default()
        .region(&resolved_region)
        .build()
}
+97 −196
Original line number Diff line number Diff line
@@ -33,9 +33,10 @@ use aws_hyper::DynConnector;
use aws_sdk_sts::Region;
use aws_types::os_shim_internal::{Env, Fs};
use aws_types::profile::ProfileParseError;
use aws_types::region::ProvideRegion;
use tracing::Instrument;

use crate::default_connector;
use crate::must_have_connector;
use crate::profile::exec::named::NamedProviderFactory;
use crate::profile::exec::{ClientConfiguration, ProviderChain};

@@ -66,6 +67,10 @@ impl AsyncProvideCredentials for ProfileFileCredentialProvider {
/// let provider = ProfileFileCredentialProvider::builder().build();
/// ```
///
/// **Note:** Profile providers to not implement any caching. They will reload and reparse the profile
/// from the file system when called. See [lazy_caching](aws_auth::provider::lazy_caching) for
/// more information about caching.
///
/// This provider supports several different credentials formats:
/// ### Credentials defined explicitly within the file
/// ```ini
@@ -117,8 +122,12 @@ impl AsyncProvideCredentials for ProfileFileCredentialProvider {
///
/// Other more complex configurations are possible, consult `test-data/assume-role-tests.json`.
pub struct ProfileFileCredentialProvider {
    inner: Result<ProviderChain, ProfileFileError>,
    factory: NamedProviderFactory,
    client_config: ClientConfiguration,
    fs: Fs,
    env: Env,
    region: Option<Region>,
    connector: DynConnector,
}

impl ProfileFileCredentialProvider {
@@ -127,10 +136,26 @@ impl ProfileFileCredentialProvider {
    }

    async fn load_credentials(&self) -> CredentialsResult {
        let inner = self.inner.as_ref().map_err(|err| {
            CredentialsError::Unhandled(format!("failed to load: {}", &err).into())
        // 1. grab a read lock, use it to see if the base profile has already been loaded
        // 2. If it's loaded, great, lets use it.
        //    If not, upgrade to a write lock and use that to load the profile file.
        // 3. Finally, downgrade to ensure no one swapped in the intervening time, then use try_load()
        //    to pull the new state.
        let profile = build_provider_chain(
            &self.fs,
            &self.env,
            &self.region,
            &self.connector,
            &self.factory,
        )
        .await;
        let inner_provider = profile.map_err(|err| match err {
            ProfileFileError::NoProfilesDefined => CredentialsError::CredentialsNotLoaded,
            _ => CredentialsError::InvalidConfiguration(
                format!("ProfileFile provider could not be built: {}", &err).into(),
            ),
        })?;
        let mut creds = match inner
        let mut creds = match inner_provider
            .base()
            .provide_credentials()
            .instrument(tracing::info_span!("load_base_credentials"))
@@ -142,10 +167,10 @@ impl ProfileFileCredentialProvider {
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to load base credentials");
                return Err(e);
                return Err(CredentialsError::ProviderError(e.into()));
            }
        };
        for provider in inner.chain().iter() {
        for provider in inner_provider.chain().iter() {
            let next_creds = provider
                .credentials(creds, &self.client_config)
                .instrument(tracing::info_span!("load_assume_role", provider = ?provider))
@@ -157,7 +182,7 @@ impl ProfileFileCredentialProvider {
                }
                Err(e) => {
                    tracing::warn!(provider = ?provider, "failed to load assume role credentials");
                    return Err(e);
                    return Err(CredentialsError::ProviderError(e.into()));
                }
            }
        }
@@ -169,6 +194,7 @@ impl ProfileFileCredentialProvider {
#[non_exhaustive]
pub enum ProfileFileError {
    CouldNotParseProfile(ProfileParseError),
    NoProfilesDefined,
    CredentialLoop {
        profiles: Vec<String>,
        next: String,
@@ -216,6 +242,7 @@ impl Display for ProfileFileError {
                "profile referenced `{}` provider but that provider is not supported",
                name
            ),
            ProfileFileError::NoProfilesDefined => write!(f, "No profiles were defined"),
        }
    }
}
@@ -231,8 +258,8 @@ impl Error for ProfileFileError {

#[derive(Default)]
pub struct Builder {
    fs: Option<Fs>,
    env: Option<Env>,
    fs: Fs,
    env: Env,
    region: Option<Region>,
    connector: Option<DynConnector>,
    custom_providers: HashMap<Cow<'static, str>, Arc<dyn AsyncProvideCredentials>>,
@@ -240,21 +267,21 @@ pub struct Builder {

impl Builder {
    pub fn fs(mut self, fs: Fs) -> Self {
        self.fs = Some(fs);
        self.fs = fs;
        self
    }

    pub fn set_fs(&mut self, fs: Option<Fs>) -> &mut Self {
    pub fn set_fs(&mut self, fs: Fs) -> &mut Self {
        self.fs = fs;
        self
    }

    pub fn env(mut self, env: Env) -> Self {
        self.env = Some(env);
        self.env = env;
        self
    }

    pub fn set_env(&mut self, env: Option<Env>) -> &mut Self {
    pub fn set_env(&mut self, env: Env) -> &mut Self {
        self.env = env;
        self
    }
@@ -292,9 +319,9 @@ impl Builder {
    pub fn build(self) -> ProfileFileCredentialProvider {
        let build_span = tracing::info_span!("build_profile_provider");
        let _enter = build_span.enter();
        let fs = self.fs.unwrap_or_default();
        let env = self.env.unwrap_or_default();
        let mut named_providers = self.custom_providers;
        let env = self.env.clone();
        let fs = self.fs;
        let mut named_providers = self.custom_providers.clone();
        named_providers
            .entry("Environment".into())
            .or_insert_with(|| {
@@ -304,215 +331,89 @@ impl Builder {
            });
        // TODO: ECS, IMDS, and other named providers
        let factory = exec::named::NamedProviderFactory::new(named_providers);
        let chain = build_provider_chain(&fs, &env, &factory);
        let connector = self.connector.or_else(default_connector).expect(
            "a connector must be provided or the `rustls` or `native-tls` features must be enabled",
        );
        let connector = self.connector.clone().unwrap_or_else(must_have_connector);
        let core_client = aws_hyper::Builder::<()>::new()
            .map_connector(|_| connector)
            .map_connector(|_| connector.clone())
            .build();

        ProfileFileCredentialProvider {
            inner: chain,
            factory,
            client_config: ClientConfiguration {
                core_client,
                region: self.region,
                region: self.region.clone(),
            },
            fs,
            env,
            region: self.region.clone(),
            connector,
        }
    }
}

fn build_provider_chain(
async fn build_provider_chain(
    fs: &Fs,
    env: &Env,
    region: &dyn ProvideRegion,
    connector: &DynConnector,
    factory: &NamedProviderFactory,
) -> Result<ProviderChain, ProfileFileError> {
    let profile_set = aws_types::profile::load(&fs, &env).map_err(|err| {
    let profile_set = aws_types::profile::load(&fs, &env).await.map_err(|err| {
        tracing::warn!(err = %err, "failed to parse profile");
        ProfileFileError::CouldNotParseProfile(err)
    })?;
    let repr = repr::resolve_chain(&profile_set)?;
    tracing::info!(chain = ?repr, "constructed abstract provider from config file");
    exec::ProviderChain::from_repr(repr, &factory)
    exec::ProviderChain::from_repr(fs.clone(), connector, region, repr, &factory)
}

#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::future::Future;
    use std::time::{Duration, UNIX_EPOCH};

    use aws_auth::provider::AsyncProvideCredentials;
    use aws_hyper::DynConnector;
    use aws_sdk_sts::Region;
    use aws_types::os_shim_internal::{Env, Fs};
    use smithy_client::dvr::{NetworkTraffic, RecordingConnection, ReplayingConnection};
    use tracing_test::traced_test;

    use crate::profile::{Builder, ProfileFileCredentialProvider};

    /// Record an interaction with a `ProfileFileCredentialProvider` to a network traffic trace
    #[allow(dead_code)]
    async fn record_test<F, T>(
        test_name: &str,
        f: impl Fn(ProfileFileCredentialProvider) -> F,
    ) -> (RecordingConnection<impl Debug>, T)
    where
        F: Future<Output = T>,
    {
        let fs = Fs::from_test_dir(
            format!("test-data/{}/aws-config", test_name),
            "/Users/me/.aws",
        );
        let env = Env::from_slice(&[("HOME", "/Users/me")]);
        let http_traffic_path = format!("test-data/{}/http-traffic.json", test_name);
        let conn = RecordingConnection::https();
        let provider = Builder::default()
            .env(env)
            .fs(fs)
            .region(Region::from_static("us-east-1"))
            .connector(DynConnector::new(conn.clone()));
        let provider = provider.build();
        let result = f(provider).await;
        let traffic = serde_json::to_string(&conn.network_traffic()).unwrap();
        std::fs::write(http_traffic_path, traffic).unwrap();
        (conn, result)
    }
    use crate::profile::Builder;
    use crate::test_case::TestEnvironment;

    async fn execute_test<F, T>(
        test_name: &str,
        f: impl Fn(ProfileFileCredentialProvider) -> F,
    ) -> (ReplayingConnection, T)
    where
        F: Future<Output = T>,
    {
        let fs = Fs::from_test_dir(
            format!("test-data/{}/aws-config", test_name),
            "/Users/me/.aws",
        );
        let env = Env::from_slice(&[("HOME", "/Users/me")]);
        let events =
            std::fs::read_to_string(format!("test-data/{}/http-traffic.json", test_name)).unwrap();
        let traffic: NetworkTraffic = serde_json::from_str(&events).unwrap();
        let conn = ReplayingConnection::new(traffic.events().clone());
        let provider = Builder::default()
    macro_rules! make_test {
        ($name: ident) => {
            #[traced_test]
            #[tokio::test]
            async fn $name() {
                TestEnvironment::from_dir(concat!(
                    "./test-data/profile-provider/",
                    stringify!($name)
                ))
                .unwrap()
                .execute(|fs, env, conn| {
                    Builder::default()
                        .env(env)
                        .fs(fs)
                        .region(Region::from_static("us-east-1"))
            .connector(DynConnector::new(conn.clone()));
        let provider = provider.build();
        (conn, f(provider).await)
    }

    #[tokio::test]
    async fn success_test() {
        let (conn, creds) = execute_test("e2e-assume-role", |provider| async move {
            provider.provide_credentials().await
                        .connector(conn)
                        .build()
                })
        .await;
        let creds = creds.expect("credentials should be valid");
        assert_eq!(creds.access_key_id(), "ASIARTESTID");
        assert_eq!(creds.secret_access_key(), "TESTSECRETKEY");
        assert_eq!(creds.session_token(), Some("TESTSESSIONTOKEN"));
        assert_eq!(
            creds.expiry(),
            Some(UNIX_EPOCH + Duration::from_secs(1628193482))
        );
        let reqs = conn.take_requests();
        assert_eq!(reqs.len(), 1);
        let req = reqs.first().unwrap();
        // TODO: perform more request validation
        assert_eq!(
            req.uri().to_string(),
            "https://sts.us-east-1.amazonaws.com/"
        );
                .await
            }

    #[tokio::test]
    async fn region_override() {
        let (conn, creds) = execute_test("region-override", |mut provider| async move {
            // manually override the region, normally this will be set by the builder during
            // provider construction
            provider.client_config.region = Some(Region::new("us-east-2"));
            provider.provide_credentials().await
        })
        .await;
        let creds = creds.expect("credentials should be valid");
        assert_eq!(creds.access_key_id(), "ASIARTESTID");
        assert_eq!(creds.secret_access_key(), "TESTSECRETKEY");
        assert_eq!(creds.session_token(), Some("TESTSESSIONTOKEN"));
        assert_eq!(
            creds.expiry(),
            Some(UNIX_EPOCH + Duration::from_secs(1628193482))
        );
        let reqs = conn.take_requests();
        assert_eq!(reqs.len(), 1);
        let req = reqs.first().unwrap();
        // TODO: perform more request validation
        assert_eq!(
            req.uri().to_string(),
            "https://sts.us-east-2.amazonaws.com/"
        );
        };
    }

    #[tokio::test]
    #[traced_test]
    async fn invalid_profile() {
        let (conn, _) = execute_test("invalid-config", |provider| async move {
            let error = provider
                .provide_credentials()
                .await
                .expect_err("config was invalid");
            assert!(
                format!("{}", error).contains("could not parse profile file"),
                "{} should contain correct error",
                error
            )
        })
        .await;
        assert!(
            conn.take_requests().is_empty(),
            "no network traffic should occur"
        );
    }
    make_test!(e2e_assume_role);
    make_test!(empty_config);
    make_test!(retry_on_error);
    make_test!(invalid_config);

    #[traced_test]
    #[tokio::test]
    async fn no_profile() {
        let (conn, _) = execute_test("empty-config", |provider| async move {
            let error = provider
                .provide_credentials()
                .await
                .expect_err("config was invalid");
            assert!(
                format!("{}", error).contains("profile `default` was not defined"),
                "{} should contain correct error",
                error
            )
    async fn region_override() {
        TestEnvironment::from_dir("./test-data/profile-provider/region_override")
            .unwrap()
            .execute(|fs, env, conn| {
                Builder::default()
                    .env(env)
                    .fs(fs)
                    .region(Region::from_static("us-east-2"))
                    .connector(conn)
                    .build()
            })
        .await;
        assert!(
            conn.take_requests().is_empty(),
            "no network traffic should occur"
        );
    }

    #[tokio::test]
    #[traced_test]
    async fn retry_on_error() {
        let (conn, creds) = execute_test("retry-on-error", |provider| async move {
            provider
                .provide_credentials()
            .await
                .expect("eventual success")
        })
        .await;
        assert_eq!(creds.access_key_id(), "ASIARTESTID");
        assert_eq!(creds.secret_access_key(), "TESTSECRETKEY");
        assert_eq!(creds.session_token(), Some("TESTSESSIONTOKEN"));
        assert_eq!(
            creds.expiry(),
            Some(UNIX_EPOCH + Duration::from_secs(1628193482))
        );
        assert_eq!(conn.take_requests().len(), 2);
    }
}
Loading