Unverified Commit d800d33e authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Use the orchestrator client for ECS and IMDS credentials in aws-config (#2997)

This ports the direct uses of the `aws_smithy_client::Client` in
aws_config to the orchestrator.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 33cd698f
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -241,3 +241,15 @@ message = "The `futures_core::stream::Stream` trait has been removed from [`Byte
references = ["smithy-rs#2983"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "ysaito1001"

[[smithy-rs]]
message = "`StaticUriEndpointResolver`'s `uri` constructor now takes a `String` instead of a `Uri`."
references = ["smithy-rs#2997"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"

[[aws-sdk-rust]]
message = "The IMDS Client builder's `build()` method is no longer async."
references = ["smithy-rs#2997"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"
+5 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ license = "Apache-2.0"
repository = "https://github.com/awslabs/smithy-rs"

[features]
client-hyper = ["aws-smithy-client/client-hyper"]
client-hyper = ["aws-smithy-client/client-hyper", "aws-smithy-runtime/connector-hyper"]
rustls = ["aws-smithy-client/rustls", "client-hyper"]
native-tls = []
allow-compilation = [] # our tests use `cargo test --all-features` and native-tls breaks CI
@@ -27,7 +27,10 @@ aws-smithy-client = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-client", de
aws-smithy-http = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http" }
aws-smithy-http-tower = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http-tower" }
aws-smithy-json = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-json" }
aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client"] }
aws-smithy-runtime-api = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime-api", features = ["client"] }
aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" }
aws-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-runtime" }
aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" }
hyper = { version = "0.14.26", default-features = false }
time = { version = "0.3.4", features = ["parsing"] }
@@ -48,6 +51,7 @@ hex = { version = "0.4.3", optional = true }
zeroize = { version = "1", optional = true }

[dev-dependencies]
aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client", "test-util"] }
futures-util = { version = "0.3.16", default-features = false }
tracing-test = "0.2.1"
tracing-subscriber = { version = "0.3.16", features = ["fmt", "json"] }
+2 −2
Original line number Diff line number Diff line
@@ -12,8 +12,8 @@
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    use aws_config::imds::Client;

    let imds = Client::builder().build().await?;
    let imds = Client::builder().build();
    let instance_id = imds.get("/latest/meta-data/instance-id").await?;
    println!("current instance id: {}", instance_id);
    println!("current instance id: {}", instance_id.as_ref());
    Ok(())
}
+11 −2
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ use aws_credential_types::provider::{self, error::CredentialsError, future, Prov
use aws_smithy_client::erase::boxclone::BoxCloneService;
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_types::error::display::DisplayErrorContext;
use http::uri::{InvalidUri, Scheme};
use http::uri::{InvalidUri, PathAndQuery, Scheme};
use http::{HeaderValue, Uri};
use tower::{Service, ServiceExt};

@@ -166,6 +166,15 @@ impl Provider {
            Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
            Err(err) => return Provider::InvalidConfiguration(err),
        };
        let path = uri.path().to_string();
        let endpoint = {
            let mut parts = uri.into_parts();
            parts.path_and_query = Some(PathAndQuery::from_static("/"));
            Uri::from_parts(parts)
        }
        .expect("parts will be valid")
        .to_string();

        let http_provider = HttpCredentialProvider::builder()
            .configure(&provider_config)
            .connector_settings(
@@ -174,7 +183,7 @@ impl Provider {
                    .read_timeout(DEFAULT_READ_TIMEOUT)
                    .build(),
            )
            .build("EcsContainer", uri);
            .build("EcsContainer", &endpoint, path);
        Provider::Configured(http_provider)
    }

+231 −196
Original line number Diff line number Diff line
@@ -8,35 +8,43 @@
//!
//! Future work will stabilize this interface and enable it to be used directly.

use crate::connector::expect_connector;
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError};
use aws_credential_types::Credentials;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::http_connector::ConnectorSettings;
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::operation::{Operation, Request};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::result::{SdkError, SdkSuccess};
use aws_smithy_http::retry::ClassifyRetry;
use aws_smithy_types::retry::{ErrorKind, RetryKind};

use crate::connector::expect_connector;
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;

use bytes::Bytes;
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::client::retries::classifier::{
    HttpStatusCodeClassifier, SmithyErrorClassifier,
};
use aws_smithy_runtime_api::client::connectors::SharedHttpConnector;
use aws_smithy_runtime_api::client::interceptors::context::{Error, InterceptorContext};
use aws_smithy_runtime_api::client::orchestrator::{
    HttpResponse, OrchestratorError, SensitiveOutput,
};
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryClassifiers, RetryReason};
use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin;
use aws_smithy_types::config_bag::Layer;
use aws_smithy_types::retry::{ErrorKind, RetryConfig};
use http::header::{ACCEPT, AUTHORIZATION};
use http::{HeaderValue, Response, Uri};
use http::{HeaderValue, Response};
use std::time::Duration;
use tower::layer::util::Identity;

const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);

#[derive(Debug)]
struct HttpProviderAuth {
    auth: Option<HeaderValue>,
}

#[derive(Debug)]
pub(crate) struct HttpCredentialProvider {
    uri: Uri,
    client: aws_smithy_client::Client<DynConnector, Identity>,
    provider_name: &'static str,
    operation: Operation<HttpProviderAuth, Credentials, CredentialsError>,
}

impl HttpCredentialProvider {
@@ -45,34 +53,13 @@ impl HttpCredentialProvider {
    }

    pub(crate) async fn credentials(&self, auth: Option<HeaderValue>) -> provider::Result {
        let credentials = self.client.call(self.operation(auth)).await;
        let credentials = self.operation.invoke(HttpProviderAuth { auth }).await;
        match credentials {
            Ok(creds) => Ok(creds),
            Err(SdkError::ServiceError(context)) => Err(context.into_err()),
            Err(other) => Err(CredentialsError::unhandled(other)),
        }
    }

    fn operation(
        &self,
        auth: Option<HeaderValue>,
    ) -> Operation<CredentialsResponseParser, HttpCredentialRetryClassifier> {
        let mut http_req = http::Request::builder()
            .uri(&self.uri)
            .header(ACCEPT, "application/json");

        if let Some(auth) = auth {
            http_req = http_req.header(AUTHORIZATION, auth);
        }
        let http_req = http_req.body(SdkBody::empty()).expect("valid request");
        Operation::new(
            Request::new(http_req),
            CredentialsResponseParser {
                provider_name: self.provider_name,
            },
        )
        .with_retry_classifier(HttpCredentialRetryClassifier)
    }
}

#[derive(Default)]
@@ -92,7 +79,12 @@ impl Builder {
        self
    }

    pub(crate) fn build(self, provider_name: &'static str, uri: Uri) -> HttpCredentialProvider {
    pub(crate) fn build(
        self,
        provider_name: &'static str,
        endpoint: &str,
        path: impl Into<String>,
    ) -> HttpCredentialProvider {
        let provider_config = self.provider_config.unwrap_or_default();
        let connector_settings = self.connector_settings.unwrap_or_else(|| {
            ConnectorSettings::builder()
@@ -104,36 +96,74 @@ impl Builder {
            "The HTTP credentials provider",
            provider_config.connector(&connector_settings),
        );
        let mut client_builder = aws_smithy_client::Client::builder()
            .connector(connector)
            .middleware(Identity::new());
        client_builder.set_sleep_impl(provider_config.sleep());
        let client = client_builder.build();
        HttpCredentialProvider {
            uri,
            client,
            provider_name,

        // The following errors are retryable:
        //   - Socket errors
        //   - Networking timeouts
        //   - 5xx errors
        //   - Non-parseable 200 responses.
        let retry_classifiers = RetryClassifiers::new()
            .with_classifier(HttpCredentialRetryClassifier)
            // Socket errors and network timeouts
            .with_classifier(SmithyErrorClassifier::<Error>::new())
            // 5xx errors
            .with_classifier(HttpStatusCodeClassifier::default());

        let mut builder = Operation::builder()
            .service_name("HttpCredentialProvider")
            .operation_name("LoadCredentials")
            .http_connector(SharedHttpConnector::new(DynConnectorAdapter::new(
                connector,
            )))
            .endpoint_url(endpoint)
            .no_auth()
            .runtime_plugin(StaticRuntimePlugin::new().with_config({
                let mut layer = Layer::new("SensitiveOutput");
                layer.store_put(SensitiveOutput);
                layer.freeze()
            }));
        if let Some(sleep_impl) = provider_config.sleep() {
            builder = builder
                .standard_retry(&RetryConfig::standard())
                .retry_classifiers(retry_classifiers)
                .sleep_impl(sleep_impl);
        } else {
            builder = builder.no_retry();
        }
        let path = path.into();
        let operation = builder
            .serializer(move |input: HttpProviderAuth| {
                let mut http_req = http::Request::builder()
                    .uri(path.clone())
                    .header(ACCEPT, "application/json");
                if let Some(auth) = input.auth {
                    http_req = http_req.header(AUTHORIZATION, auth);
                }
                Ok(http_req.body(SdkBody::empty()).expect("valid request"))
            })
            .deserializer(move |response| parse_response(provider_name, response))
            .build();
        HttpCredentialProvider { operation }
    }
}

#[derive(Clone, Debug)]
struct CredentialsResponseParser {
fn parse_response(
    provider_name: &'static str,
}
impl ParseStrictResponse for CredentialsResponseParser {
    type Output = provider::Result;

    fn parse(&self, response: &Response<Bytes>) -> Self::Output {
    response: &Response<SdkBody>,
) -> Result<Credentials, OrchestratorError<CredentialsError>> {
    if !response.status().is_success() {
            return Err(CredentialsError::provider_error(format!(
        return Err(OrchestratorError::operation(
            CredentialsError::provider_error(format!(
                "Non-success status from HTTP credential provider: {:?}",
                response.status()
            )));
            )),
        ));
    }
        let str_resp =
            std::str::from_utf8(response.body().as_ref()).map_err(CredentialsError::unhandled)?;
        let json_creds = parse_json_credentials(str_resp).map_err(CredentialsError::unhandled)?;
    let resp_bytes = response.body().bytes().expect("non-streaming deserializer");
    let str_resp = std::str::from_utf8(resp_bytes)
        .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
    let json_creds = parse_json_credentials(str_resp)
        .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
    match json_creds {
        JsonCredentials::RefreshableCredentials(RefreshableCredentials {
            access_key_id,
@@ -145,157 +175,162 @@ impl ParseStrictResponse for CredentialsResponseParser {
            secret_access_key,
            Some(session_token.to_string()),
            Some(expiration),
                self.provider_name,
            provider_name,
        )),
        JsonCredentials::Error { code, message } => Err(OrchestratorError::operation(
            CredentialsError::provider_error(format!(
                "failed to load credentials [{}]: {}",
                code, message
            )),
            JsonCredentials::Error { code, message } => Err(CredentialsError::provider_error(
                format!("failed to load credentials [{}]: {}", code, message),
        )),
        }
    }

    fn sensitive(&self) -> bool {
        true
    }
}

#[derive(Clone, Debug)]
struct HttpCredentialRetryClassifier;

impl ClassifyRetry<SdkSuccess<Credentials>, SdkError<CredentialsError>>
    for HttpCredentialRetryClassifier
{
    fn classify_retry(
        &self,
        response: Result<&SdkSuccess<Credentials>, &SdkError<CredentialsError>>,
    ) -> RetryKind {
        /* The following errors are retryable:
         *   - Socket errors
         *   - Networking timeouts
         *   - 5xx errors
         *   - Non-parseable 200 responses.
         *  */
        match response {
            Ok(_) => RetryKind::Unnecessary,
            // socket errors, networking timeouts
            Err(SdkError::DispatchFailure(client_err))
                if client_err.is_timeout() || client_err.is_io() =>
            {
                RetryKind::Error(ErrorKind::TransientError)
            }
            // non-parseable 200s
            Err(SdkError::ServiceError(context))
                if matches!(context.err(), CredentialsError::Unhandled { .. })
                    && context.raw().http().status().is_success() =>
            {
                RetryKind::Error(ErrorKind::ServerError)
impl ClassifyRetry for HttpCredentialRetryClassifier {
    fn name(&self) -> &'static str {
        "HttpCredentialRetryClassifier"
    }
            // 5xx errors
            Err(SdkError::ResponseError(context))
                if context.raw().http().status().is_server_error() =>
            {
                RetryKind::Error(ErrorKind::ServerError)
            }
            Err(SdkError::ServiceError(context))
                if context.raw().http().status().is_server_error() =>

    fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
        let output_or_error = ctx.output_or_error()?;
        let error = match output_or_error {
            Ok(_) => return None,
            Err(err) => err,
        };

        // Retry non-parseable 200 responses
        if let Some((err, status)) = error
            .as_operation_error()
            .and_then(|err| err.downcast_ref::<CredentialsError>())
            .zip(ctx.response().map(HttpResponse::status))
        {
                RetryKind::Error(ErrorKind::ServerError)
            if matches!(err, CredentialsError::Unhandled { .. }) && status.is_success() {
                return Some(RetryReason::Error(ErrorKind::ServerError));
            }
            Err(_) => RetryKind::UnretryableFailure,
        }

        None
    }
}

#[cfg(test)]
mod test {
    use crate::http_credential_provider::{
        CredentialsResponseParser, HttpCredentialRetryClassifier,
    };
    use super::*;
    use aws_credential_types::provider::error::CredentialsError;
    use aws_credential_types::Credentials;
    use aws_smithy_client::test_connection::TestConnection;
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::operation;
    use aws_smithy_http::response::ParseStrictResponse;
    use aws_smithy_http::result::{SdkError, SdkSuccess};
    use aws_smithy_http::retry::ClassifyRetry;
    use aws_smithy_types::retry::{ErrorKind, RetryKind};
    use bytes::Bytes;

    fn sdk_resp(
        resp: http::Response<&'static str>,
    ) -> Result<SdkSuccess<Credentials>, SdkError<CredentialsError>> {
        let resp = resp.map(|data| Bytes::from_static(data.as_bytes()));
        match (CredentialsResponseParser {
            provider_name: "test",
        })
        .parse(&resp)
        {
            Ok(creds) => Ok(SdkSuccess {
                raw: operation::Response::new(resp.map(SdkBody::from)),
                parsed: creds,
            }),
            Err(err) => Err(SdkError::service_error(
                err,
                operation::Response::new(resp.map(SdkBody::from)),
            )),
        }
    }

    #[test]
    fn non_parseable_is_retriable() {
        let bad_response = http::Response::builder()
            .status(200)
            .body("notjson")
            .unwrap();
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use http::{Request, Response, Uri};
    use std::time::SystemTime;

        assert_eq!(
            HttpCredentialRetryClassifier.classify_retry(sdk_resp(bad_response).as_ref()),
            RetryKind::Error(ErrorKind::ServerError)
        );
    async fn provide_creds(
        connector: TestConnection<SdkBody>,
    ) -> Result<Credentials, CredentialsError> {
        let provider_config = ProviderConfig::default().with_http_connector(connector.clone());
        let provider = HttpCredentialProvider::builder()
            .configure(&provider_config)
            .build("test", "http://localhost:1234/", "/some-creds");
        provider.credentials(None).await
    }

    #[test]
    fn ok_response_not_retriable() {
        let ok_response = http::Response::builder()
    fn successful_req_resp() -> (HttpRequest, HttpResponse) {
        (
            Request::builder()
                .uri(Uri::from_static("http://localhost:1234/some-creds"))
                .body(SdkBody::empty())
                .unwrap(),
            Response::builder()
                .status(200)
            .body(
                .body(SdkBody::from(
                    r#"{
                        "AccessKeyId" : "MUA...",
                        "SecretAccessKey" : "/7PC5om....",
                        "Token" : "AQoDY....=",
                        "Expiration" : "2016-02-25T06:03:31Z"
                    }"#,
                ))
                .unwrap(),
        )
            .unwrap();
        let sdk_result = sdk_resp(ok_response);
    }

    #[tokio::test]
    async fn successful_response() {
        let connector = TestConnection::new(vec![successful_req_resp()]);
        let creds = provide_creds(connector.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        assert_eq!("/7PC5om....", creds.secret_access_key());
        assert_eq!(Some("AQoDY....="), creds.session_token());
        assert_eq!(
            HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()),
            RetryKind::Unnecessary
            Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1456380211)),
            creds.expiry()
        );
        connector.assert_requests_match(&[]);
    }

        assert!(sdk_result.is_ok(), "should be ok: {:?}", sdk_result)
    #[tokio::test]
    async fn retry_nonparseable_response() {
        let connector = TestConnection::new(vec![
            (
                Request::builder()
                    .uri(Uri::from_static("http://localhost:1234/some-creds"))
                    .body(SdkBody::empty())
                    .unwrap(),
                Response::builder()
                    .status(200)
                    .body(SdkBody::from(r#"not json"#))
                    .unwrap(),
            ),
            successful_req_resp(),
        ]);
        let creds = provide_creds(connector.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        connector.assert_requests_match(&[]);
    }

    #[test]
    fn explicit_error_not_retriable() {
        let error_response = http::Response::builder()
            .status(400)
            .body(r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#)
            .unwrap();
        let sdk_result = sdk_resp(error_response);
        assert_eq!(
            HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()),
            RetryKind::UnretryableFailure
        );
        let sdk_error = sdk_result.expect_err("should be error");
    #[tokio::test]
    async fn retry_error_code() {
        let connector = TestConnection::new(vec![
            (
                Request::builder()
                    .uri(Uri::from_static("http://localhost:1234/some-creds"))
                    .body(SdkBody::empty())
                    .unwrap(),
                Response::builder()
                    .status(500)
                    .body(SdkBody::from(r#"it broke"#))
                    .unwrap(),
            ),
            successful_req_resp(),
        ]);
        let creds = provide_creds(connector.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        connector.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn explicit_error_not_retriable() {
        let connector = TestConnection::new(vec![(
            Request::builder()
                .uri(Uri::from_static("http://localhost:1234/some-creds"))
                .body(SdkBody::empty())
                .unwrap(),
            Response::builder()
                .status(400)
                .body(SdkBody::from(
                    r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#,
                ))
                .unwrap(),
        )]);
        let err = provide_creds(connector.clone())
            .await
            .expect_err("it should fail");
        assert!(
            matches!(
                sdk_error,
                SdkError::ServiceError(ref context) if matches!(context.err(), CredentialsError::ProviderError { .. })
            ),
            "should be provider error: {}",
            sdk_error
            matches!(err, CredentialsError::ProviderError { .. }),
            "should be CredentialsError::ProviderError: {err}",
        );
        connector.assert_requests_match(&[]);
    }
}
Loading