Unverified Commit 607c8918 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Refactor aws-config tests to allow other identity types (#3440)

The aws-config test environment and test-data makes a heavy assumption
that there will only ever be credentials. This PR proactively refactors
the tests to make room for other future types of identities, such as
tokens.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 6187f884
Loading
Loading
Loading
Loading
+38 −30
Original line number Diff line number Diff line
@@ -198,18 +198,18 @@ impl Builder {

#[cfg(test)]
mod test {
    use crate::test_case::TestEnvironment;
    use crate::{
        default_provider::credentials::DefaultCredentialsChain, test_case::StaticTestProvider,
    };
    use aws_credential_types::provider::ProvideCredentials;
    use aws_smithy_async::time::StaticTimeSource;
    use std::time::UNIX_EPOCH;

    use crate::default_provider::credentials::DefaultCredentialsChain;

    use crate::test_case::TestEnvironment;

    /// Test generation macro
    ///
    /// # Examples
    /// **Run the test case in `test-data/default-provider-chain/test_name`
    /// **Run the test case in `test-data/default-credential-provider-chain/test_name`
    /// ```no_run
    /// make_test!(test_name);
    /// ```
@@ -245,23 +245,24 @@ mod test {
            $(#[$m])*
            #[tokio::test]
            async fn $name() {
                crate::test_case::TestEnvironment::from_dir(concat!(
                    "./test-data/default-provider-chain/",
                    stringify!($name)
                ))
                .await
                .unwrap()
                .with_provider_config($provider_config_builder)
                .$func(|conf| {
                    let conf = conf.clone();
                let _ = crate::test_case::TestEnvironment::from_dir(
                    concat!("./test-data/default-credential-provider-chain/", stringify!($name)),
                    crate::test_case::test_credentials_provider(|config| {
                        async move {
                            crate::default_provider::credentials::Builder::default()
                            .configure(conf)
                                .configure(config)
                                .build()
                                .await
                                .provide_credentials()
                                .await
                        }
                    })
                )
                .await
                .unwrap()
                .map_provider_config($provider_config_builder)
                .$func()
                .await;
            }
        };
    }
@@ -312,21 +313,28 @@ mod test {

    #[tokio::test]
    async fn profile_name_override() {
        let conf =
            TestEnvironment::from_dir("./test-data/default-provider-chain/profile_static_keys")
        // Only use the TestEnvironment to create a ProviderConfig from the
        // profile_static_keys test directory. We don't actually want to
        // use the expected test output from that directory since we're
        // overriding the profile name on the credentials chain in this test.
        let provider_config = TestEnvironment::<crate::test_case::Credentials, ()>::from_dir(
            "./test-data/default-credential-provider-chain/profile_static_keys",
            StaticTestProvider::new(|_| unreachable!()),
        )
        .await
        .unwrap()
        .provider_config()
        .clone();
        let provider = DefaultCredentialsChain::builder()

        let creds = DefaultCredentialsChain::builder()
            .profile_name("secondary")
            .configure(conf)
            .configure(provider_config)
            .build()
            .await;
        let creds = provider
            .await
            .provide_credentials()
            .await
            .expect("creds should load");

        assert_eq!(creds.access_key_id(), "correct_key_secondary");
    }

+1 −1
Original line number Diff line number Diff line
@@ -478,7 +478,7 @@ mod test {
            let uri = Provider::uri(env, Some(TestDns::default().into_shared()))
                .await
                .map(|uri| uri.to_string());
            self.result.assert_matches(uri);
            self.result.assert_matches(uri.as_ref());
        }
    }

+13 −6
Original line number Diff line number Diff line
@@ -565,19 +565,26 @@ impl ChainProvider {
mod test {
    use crate::profile::credentials::Builder;
    use crate::test_case::TestEnvironment;
    use aws_credential_types::provider::ProvideCredentials;

    macro_rules! make_test {
        ($name: ident) => {
            #[tokio::test]
            async fn $name() {
                TestEnvironment::from_dir(concat!(
                    "./test-data/profile-provider/",
                    stringify!($name)
                ))
                let _ = TestEnvironment::from_dir(
                    concat!("./test-data/profile-provider/", stringify!($name)),
                    crate::test_case::test_credentials_provider(|config| async move {
                        Builder::default()
                            .configure(&config)
                            .build()
                            .provide_credentials()
                            .await
                .unwrap()
                .execute(|conf| async move { Builder::default().configure(&conf).build() })
                    }),
                )
                .await
                .unwrap()
                .execute()
                .await;
            }
        };
    }
+173 −87
Original line number Diff line number Diff line
@@ -6,12 +6,12 @@
use crate::default_provider::use_dual_stack::use_dual_stack_provider;
use crate::default_provider::use_fips::use_fips_provider;
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, ProvideCredentials};
use aws_smithy_async::rt::sleep::{AsyncSleep, Sleep, TokioSleep};
use aws_smithy_runtime::client::http::test_util::dvr::{
    NetworkTraffic, RecordingClient, ReplayingClient,
};
use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::os_shim_internal::{Env, Fs};
@@ -22,18 +22,38 @@ use std::error::Error;
use std::fmt::Debug;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::time::{Duration, UNIX_EPOCH};

mod sealed {
    /// Trait that provides secret values for a given test output type (credentials or tokens)
    pub(crate) trait Secrets {
        fn secrets(&self) -> Vec<String>;
    }
    impl Secrets for () {
        fn secrets(&self) -> Vec<String> {
            Vec::new()
        }
    }
}
use sealed::Secrets;

/// Test case credentials
///
/// Credentials for use in test cases. These implement Serialize/Deserialize and have a
/// non-hidden debug implementation.
#[derive(Deserialize, Debug, Eq, PartialEq)]
struct Credentials {
    access_key_id: String,
    secret_access_key: String,
    session_token: Option<String>,
    expiry: Option<u64>,
pub(crate) struct Credentials {
    pub(crate) access_key_id: String,
    pub(crate) secret_access_key: String,
    pub(crate) session_token: Option<String>,
    pub(crate) expiry: Option<u64>,
}

impl Secrets for Credentials {
    fn secrets(&self) -> Vec<String> {
        vec![self.secret_access_key.clone()]
    }
}

/// Convert real credentials to test credentials
@@ -59,20 +79,6 @@ impl From<aws_credential_types::Credentials> for Credentials {
    }
}

/// Credentials test environment
///
/// A credentials test environment is a directory containing:
/// - an `fs` directory. This is loaded into the test as if it was mounted at `/`
/// - an `env.json` file containing environment variables
/// - an  `http-traffic.json` file containing an http traffic log from [`dvr`](aws_smithy_runtime::client::http::test_utils::dvr)
/// - a `test-case.json` file defining the expected output of the test
pub(crate) struct TestEnvironment {
    metadata: Metadata,
    base_dir: PathBuf,
    http_client: ReplayingClient,
    provider_config: ProviderConfig,
}

/// Connector which expects no traffic
pub(crate) fn no_traffic_client() -> SharedHttpClient {
    ReplayingClient::new(Vec::new()).into_shared()
@@ -92,48 +98,123 @@ pub(crate) enum GenericTestResult<T> {
    ErrorContains(String),
}

impl<T> GenericTestResult<T>
impl<O> GenericTestResult<O>
where
    T: PartialEq + Debug,
    O: PartialEq + Debug,
{
    #[track_caller]
    pub(crate) fn assert_matches(&self, result: Result<impl Into<T>, impl Error>) {
    pub(crate) fn assert_matches<E>(&self, result: Result<&O, &E>)
    where
        E: Error,
    {
        match (result, &self) {
            (Ok(actual), GenericTestResult::Ok(expected)) => {
                assert_eq!(expected, &actual.into(), "incorrect result was returned")
                assert_eq!(expected, actual, "incorrect result was returned")
            }
            (Err(err), GenericTestResult::ErrorContains(substr)) => {
                let message = format!("{}", DisplayErrorContext(&err));
                let message = format!("{}", DisplayErrorContext(err));
                assert!(
                    message.contains(substr),
                    "`{message}` did not contain `{substr}`"
                );
            }
            (Err(actual_error), GenericTestResult::Ok(expected_creds)) => panic!(
            (Err(actual_error), GenericTestResult::Ok(expected_creds)) => {
                panic!(
                    "expected credentials ({:?}) but an error was returned: {}",
                    expected_creds,
                DisplayErrorContext(&actual_error)
            ),
            (Ok(creds), GenericTestResult::ErrorContains(substr)) => panic!(
                    DisplayErrorContext(actual_error)
                )
            }
            (Ok(output), GenericTestResult::ErrorContains(substr)) => panic!(
                "expected an error containing: `{}`, but a result was returned: {:?}",
                substr,
                creds.into()
                substr, output
            ),
        }
    }
}

type TestResult = GenericTestResult<Credentials>;

#[derive(Deserialize)]
pub(crate) struct Metadata {
    result: TestResult,
pub(crate) struct Metadata<T> {
    result: GenericTestResult<T>,
    docs: String,
    name: String,
}

impl TestEnvironment {
    pub(crate) async fn from_dir(dir: impl AsRef<Path>) -> Result<TestEnvironment, Box<dyn Error>> {
pub(crate) trait RunTestProvider {
    type Output: for<'a> Deserialize<'a> + Secrets;
    type Error;

    fn run_provider(
        &self,
        provider_config: ProviderConfig,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send + 'static>>;
}

type ResultFuture<O, E> = Pin<Box<dyn Future<Output = Result<O, E>> + Send + 'static>>;
pub(crate) struct StaticTestProvider<O, E> {
    run_provider_fn: Box<dyn Fn(ProviderConfig) -> ResultFuture<O, E> + 'static>,
}
impl<O, E> StaticTestProvider<O, E> {
    pub(crate) fn new<F>(run_provider_fn: F) -> Self
    where
        F: Fn(ProviderConfig) -> ResultFuture<O, E> + 'static,
    {
        Self {
            run_provider_fn: Box::new(run_provider_fn) as _,
        }
    }
}
impl<O, E> RunTestProvider for StaticTestProvider<O, E>
where
    O: for<'a> Deserialize<'a> + Secrets,
{
    type Output = O;
    type Error = E;

    fn run_provider(
        &self,
        provider_config: ProviderConfig,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send + 'static>> {
        (self.run_provider_fn)(provider_config)
    }
}

pub(crate) fn test_credentials_provider<F, Fut, E>(
    run_provider_fn: F,
) -> impl RunTestProvider<Output = Credentials, Error = E>
where
    F: Fn(ProviderConfig) -> Fut + Send + Clone + 'static,
    Fut: Future<Output = Result<aws_credential_types::Credentials, E>> + Send,
{
    StaticTestProvider::<Credentials, E>::new(move |config| {
        let run_provider_fn = run_provider_fn.clone();
        Box::pin(async move { (run_provider_fn)(config).await.map(Into::into) })
    })
}

/// Provider test environment
///
/// A provider test environment is a directory containing:
/// - an `fs` directory. This is loaded into the test as if it was mounted at `/`
/// - an `env.json` file containing environment variables
/// - an  `http-traffic.json` file containing an http traffic log from [`dvr`](aws_smithy_runtime::client::http::test_utils::dvr)
/// - a `test-case.json` file defining the expected output of the test
pub(crate) struct TestEnvironment<O, E> {
    metadata: Metadata<O>,
    base_dir: PathBuf,
    http_client: ReplayingClient,
    provider_config: ProviderConfig,
    run_provider: Box<dyn RunTestProvider<Output = O, Error = E>>,
}

impl<O, E> TestEnvironment<O, E>
where
    O: for<'a> Deserialize<'a>,
{
    pub(crate) async fn from_dir(
        dir: impl AsRef<Path>,
        run_provider: impl RunTestProvider<Output = O, Error = E> + 'static,
    ) -> Result<Self, BoxError> {
        let dir = dir.as_ref();
        let env = std::fs::read_to_string(dir.join("env.json"))
            .map_err(|e| format!("failed to load env: {}", e))?;
@@ -145,7 +226,7 @@ impl TestEnvironment {
            .map_err(|e| format!("failed to load http traffic: {}", e))?;
        let network_traffic: NetworkTraffic = serde_json::from_str(&network_traffic)?;

        let metadata: Metadata = serde_json::from_str(
        let metadata: Metadata<O> = serde_json::from_str(
            &std::fs::read_to_string(dir.join("test-case.json"))
                .map_err(|e| format!("failed to load test case: {}", e))?,
        )?;
@@ -164,15 +245,16 @@ impl TestEnvironment {
            .with_use_fips(use_fips)
            .with_use_dual_stack(use_dual_stack);

        Ok(TestEnvironment {
        Ok(Self {
            base_dir: dir.into(),
            metadata,
            http_client,
            provider_config,
            run_provider: Box::new(run_provider),
        })
    }

    pub(crate) fn with_provider_config<F>(mut self, provider_config_builder: F) -> Self
    pub(crate) fn map_provider_config<F>(mut self, provider_config_builder: F) -> Self
    where
        F: Fn(ProviderConfig) -> ProviderConfig,
    {
@@ -183,20 +265,22 @@ impl TestEnvironment {
    pub(crate) fn provider_config(&self) -> &ProviderConfig {
        &self.provider_config
    }
}

impl<O, E> TestEnvironment<O, E>
where
    O: for<'a> Deserialize<'a> + Secrets + PartialEq + Debug,
    E: Error,
{
    #[allow(unused)]
    #[cfg(all(feature = "client-hyper", feature = "rustls"))]
    /// Record a test case from live (remote) HTTPS traffic
    ///
    /// The `default_connector()` from the crate will be used
    pub(crate) async fn execute_from_live_traffic<F, P>(
        &self,
        make_provider: impl Fn(ProviderConfig) -> F,
    ) where
        F: Future<Output = P>,
        P: ProvideCredentials,
    {
    pub(crate) async fn execute_from_live_traffic(&self) {
        // swap out the connector generated from `http-traffic.json` for a real connector:

        use std::error::Error;
        let live_connector = aws_smithy_runtime::client::http::hyper_014::default_connector(
            &Default::default(),
            self.provider_config.sleep_impl(),
@@ -207,14 +291,13 @@ impl TestEnvironment {
            .provider_config
            .clone()
            .with_http_client(live_client.clone());
        let provider = make_provider(config).await;
        let result = provider.provide_credentials().await;
        let result = self.run_provider.run_provider(config).await;
        std::fs::write(
            self.base_dir.join("http-traffic-recorded.json"),
            serde_json::to_string(&live_client.network_traffic()).unwrap(),
        )
        .unwrap();
        self.check_results(result);
        self.check_results(result.as_ref());
    }

    #[allow(dead_code)]
@@ -222,55 +305,31 @@ impl TestEnvironment {
    ///
    /// A connector will be created with the factory, then request traffic will be recorded.
    /// Response are generated from the existing http-traffic.json.
    pub(crate) async fn execute_and_update<F, P>(&self, make_provider: impl Fn(ProviderConfig) -> F)
    where
        F: Future<Output = P>,
        P: ProvideCredentials,
    {
    pub(crate) async fn execute_and_update(&self) {
        let recording_client = RecordingClient::new(self.http_client.clone());
        let config = self
            .provider_config
            .clone()
            .with_http_client(recording_client.clone());
        let provider = make_provider(config).await;
        let result = provider.provide_credentials().await;
        let result = self.run_provider.run_provider(config).await;
        std::fs::write(
            self.base_dir.join("http-traffic-recorded.json"),
            serde_json::to_string(&recording_client.network_traffic()).unwrap(),
        )
        .unwrap();
        self.check_results(result);
    }

    fn log_info(&self) {
        eprintln!("test case: {}. {}", self.metadata.name, self.metadata.docs);
    }

    fn lines_with_secrets<'a>(&'a self, logs: &'a str) -> Vec<&'a str> {
        logs.lines().filter(|l| self.contains_secret(l)).collect()
    }

    fn contains_secret(&self, log_line: &str) -> bool {
        assert!(log_line.lines().count() <= 1);
        match &self.metadata.result {
            // NOTE: we aren't currently erroring if the session token is leaked, that is in the canonical request among other things
            TestResult::Ok(creds) => log_line.contains(&creds.secret_access_key),
            TestResult::ErrorContains(_) => false,
        }
        self.check_results(result.as_ref());
    }

    /// Execute a test case. Failures lead to panics.
    pub(crate) async fn execute<F, P>(&self, make_provider: impl Fn(ProviderConfig) -> F)
    where
        F: Future<Output = P>,
        P: ProvideCredentials,
    {
    pub(crate) async fn execute(&self) -> Result<O, E> {
        let (_guard, rx) = capture_test_logs();
        let provider = make_provider(self.provider_config.clone()).await;
        let result = provider.provide_credentials().await;
        let result = self
            .run_provider
            .run_provider(self.provider_config.clone())
            .await;
        tokio::time::pause();
        self.log_info();
        self.check_results(result);
        self.check_results(result.as_ref());
        // todo: validate bodies
        match self
            .http_client
@@ -291,11 +350,38 @@ impl TestEnvironment {
            "secret was exposed\n{:?}\nSee the following log lines:\n  {}",
            self.metadata.result,
            leaking_lines.join("\n  ")
        )
        );
        result
    }

    fn log_info(&self) {
        eprintln!("test case: {}. {}", self.metadata.name, self.metadata.docs);
    }

    fn lines_with_secrets<'a>(&'a self, logs: &'a str) -> Vec<&'a str> {
        logs.lines()
            .filter(|l| self.contains_any_secrets(l))
            .collect()
    }

    fn contains_any_secrets(&self, log_line: &str) -> bool {
        assert!(log_line.lines().count() <= 1);
        match &self.metadata.result {
            // NOTE: we aren't currently erroring if the session token is leaked, that is in the canonical request among other things
            GenericTestResult::Ok(output) => output
                .secrets()
                .iter()
                .any(|secret| log_line.contains(secret)),
            GenericTestResult::ErrorContains(_) => false,
        }
    }

    #[track_caller]
    fn check_results(&self, result: provider::Result) {
    fn check_results(&self, result: Result<&O, &E>)
    where
        O: PartialEq + Debug,
        E: Error,
    {
        self.metadata.result.assert_matches(result);
    }
}
+0 −0

File moved.

Loading