Unverified Commit acb693fa authored by Landon James's avatar Landon James Committed by GitHub
Browse files

IMDS retries (#3975)

## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
Addressing https://github.com/awslabs/aws-sdk-rust/issues/1233

## Description
<!--- Describe your changes in detail -->
Add ability to configure the retry classifier on the IMDS client.

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Added a new unit test to ensure that a user-defined retry classifier is
being used.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 4ad631a6
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
---
applies_to: ["client"]
authors: ["landonxjames"]
references: ["aws-sdk-rust#1233"]
breaking: false
new_feature: true
bug_fix: false
---

Allow IMDS clients to be configured with a user-provided `SharedRetryClassifier`.
+1 −1
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"

[[package]]
name = "aws-config"
version = "1.5.14"
version = "1.5.15"
dependencies = [
 "aws-credential-types",
 "aws-runtime",
+2 −0
Original line number Diff line number Diff line
@@ -32,6 +32,8 @@ allowed_external_types = [
   "aws_smithy_runtime_api::client::identity::ResolveIdentity",
   "aws_smithy_runtime_api::client::orchestrator::HttpResponse",
   "aws_smithy_runtime_api::client::result::SdkError",
   "aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry",
   "aws_smithy_runtime_api::client::retries::classifiers::SharedRetryClassifier",
   "aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig",
   "aws_smithy_types::body::SdkBody",
   "aws_smithy_types::checksum_config::RequestChecksumCalculation",
+157 −8
Original line number Diff line number Diff line
@@ -52,6 +52,8 @@ const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(21_600);
// Fallback passed to `RetryConfig::with_max_attempts` when the builder has no `max_attempts`.
const DEFAULT_ATTEMPTS: u32 = 4;
// Fallback connect timeout used by the builder when `connect_timeout` is unset.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
// Fallback read timeout used by the builder when `read_timeout` is unset.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
// Fallback operation timeout used by the builder when `operation_timeout` is unset.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
// Fallback operation attempt timeout used by the builder when `operation_attempt_timeout` is unset.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);

fn user_agent() -> AwsUserAgent {
    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
@@ -238,6 +240,7 @@ impl ImdsCommonRuntimePlugin {
        config: &ProviderConfig,
        endpoint_resolver: ImdsEndpointResolver,
        retry_config: RetryConfig,
        retry_classifier: SharedRetryClassifier,
        timeout_config: TimeoutConfig,
    ) -> Self {
        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
@@ -254,7 +257,7 @@ impl ImdsCommonRuntimePlugin {
                .with_http_client(config.http_client())
                .with_endpoint_resolver(Some(endpoint_resolver))
                .with_interceptor(UserAgentInterceptor::new())
                .with_retry_classifier(SharedRetryClassifier::new(ImdsResponseRetryClassifier))
                .with_retry_classifier(retry_classifier)
                .with_retry_strategy(Some(StandardRetryStrategy::new()))
                .with_time_source(Some(config.time_source()))
                .with_sleep_impl(config.sleep_impl()),
@@ -322,7 +325,10 @@ pub struct Builder {
    token_ttl: Option<Duration>,
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
    operation_timeout: Option<Duration>,
    operation_attempt_timeout: Option<Duration>,
    config: Option<ProviderConfig>,
    retry_classifier: Option<SharedRetryClassifier>,
}

impl Builder {
@@ -398,6 +404,32 @@ impl Builder {
        self
    }

    /// Override the operation timeout for IMDS
    ///
    /// This value defaults to 30 seconds
    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
        self.operation_timeout = Some(timeout);
        self
    }

    /// Override the operation attempt timeout for IMDS
    ///
    /// This value defaults to 10 seconds
    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
        self.operation_attempt_timeout = Some(timeout);
        self
    }

    /// Override the retry classifier for IMDS
    ///
    /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this
    /// module offers some configuration options and can be wrapped by [SharedRetryClassifier::new()] for use
    /// here, or you can create your own fully customized [SharedRetryClassifier].
    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
        self.retry_classifier = Some(retry_classifier);
        self
    }

    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
    /*
    pub fn port(mut self, port: u32) -> Self {
@@ -411,6 +443,11 @@ impl Builder {
        let timeout_config = TimeoutConfig::builder()
            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
            .operation_attempt_timeout(
                self.operation_attempt_timeout
                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
            )
            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
            .build();
        let endpoint_source = self
            .endpoint
@@ -421,10 +458,14 @@ impl Builder {
        };
        let retry_config = RetryConfig::standard()
            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
            ImdsResponseRetryClassifier::default(),
        ));
        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
            &config,
            endpoint_resolver,
            retry_config,
            retry_classifier,
            timeout_config,
        ));
        let operation = Operation::builder()
@@ -549,8 +590,20 @@ impl ResolveEndpoint for ImdsEndpointResolver {
/// - 403 (IMDS disabled): **Not Retryable**
/// - 404 (Not found): **Not Retryable**
/// - >=500 (server error): **Retryable**
#[derive(Clone, Debug)]
struct ImdsResponseRetryClassifier;
/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()]
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct ImdsResponseRetryClassifier {
    /// When `true`, `classify_retry` returns `RetryAction::server_error()` on the branch that
    /// would otherwise decline to retry timeouts. The derived `Default` leaves this `false`.
    retry_connect_timeouts: bool,
}

impl ImdsResponseRetryClassifier {
    /// Indicate whether the IMDS client should retry on connection timeouts
    ///
    /// Consumes the classifier and returns it with the flag set to `retry_connect_timeouts`.
    /// Timeouts are not retried unless this is called with `true`.
    pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
        self.retry_connect_timeouts = retry_connect_timeouts;
        self
    }
}

impl ClassifyRetry for ImdsResponseRetryClassifier {
    fn name(&self) -> &'static str {
@@ -567,7 +620,10 @@ impl ClassifyRetry for ImdsResponseRetryClassifier {
                // This catch-all includes successful responses that fail to parse. These should not be retried.
                _ => RetryAction::NoActionIndicated,
            }
        } else if self.retry_connect_timeouts {
            RetryAction::server_error()
        } else {
            // This is the default behavior.
            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
            // credentials provider chain to fail to provide credentials.
            // Also don't retry non-responses.
@@ -593,7 +649,9 @@ pub(crate) mod test {
        HttpRequest, HttpResponse, OrchestratorError,
    };
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use aws_types::os_shim_internal::{Env, Fs};
@@ -603,6 +661,7 @@ pub(crate) mod test {
    use std::collections::HashMap;
    use std::error::Error;
    use std::io;
    use std::time::SystemTime;
    use std::time::{Duration, UNIX_EPOCH};
    use tracing_test::traced_test;

@@ -933,7 +992,7 @@ pub(crate) mod test {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
        let classifier = ImdsResponseRetryClassifier;
        let classifier = ImdsResponseRetryClassifier::default();
        assert_eq!(
            RetryAction::NoActionIndicated,
            classifier.classify_retry(&ctx)
@@ -950,6 +1009,65 @@ pub(crate) mod test {
        );
    }

    /// A retry classifier supplied through the builder is the one the client consults
    #[tokio::test]
    async fn user_provided_retry_classifier() {
        /// Classifier that forbids a retry no matter what the attempt produced.
        #[derive(Clone, Debug)]
        struct UserProvidedRetryClassifier;

        impl ClassifyRetry for UserProvidedRetryClassifier {
            fn name(&self) -> &'static str {
                "UserProvidedRetryClassifier"
            }

            // Every attempt is final: retries are always forbidden
            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
                RetryAction::RetryForbidden
            }
        }

        // The response to the first metadata request: a bare 401 with no body.
        let unauthorized = http::Response::builder()
            .status(401)
            .body(SdkBody::empty())
            .unwrap();

        // Recorded traffic: token fetch, 401, then a second token fetch and a successful
        // response that would only be reached if the 401 were retried.
        let http_client = StaticReplayClient::new(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(0, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                unauthorized,
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
                imds_response("ok"),
            ),
        ]);

        let provider_config = ProviderConfig::no_configuration()
            .with_sleep_impl(InstantSleep::unlogged())
            .with_http_client(http_client.clone());
        let imds_client = super::Client::builder()
            .configure(&provider_config)
            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
            .build();

        let failure = imds_client
            .get("/latest/metadata")
            .await
            .expect_err("Client should error");

        // The operation must surface the initial 401 rather than retrying through to the
        // 200, because the user-provided classifier never allows a retry.
        assert_full_error_contains!(failure, "401");
    }

    // since tokens are sent as headers, the tokens need to be valid header values
    #[tokio::test]
    async fn invalid_token() {
@@ -989,9 +1107,6 @@ pub(crate) mod test {
    #[cfg(feature = "rustls")]
    async fn one_second_connect_timeout() {
        use crate::imds::client::ImdsError;
        use aws_smithy_types::error::display::DisplayErrorContext;
        use std::time::SystemTime;

        let client = Client::builder()
            // 240.* can never be resolved
            .endpoint("http://240.0.0.0")
@@ -1023,6 +1138,40 @@ pub(crate) mod test {
        );
    }

    /// Retry classifier retries timeouts when configured to. The operation timeout is lowered
    /// to 1s here, and the call is expected to fail after more than 1s but less than 2s.
    #[tokio::test]
    async fn retry_connect_timeouts() {
        let timeout_retrying_classifier = SharedRetryClassifier::new(
            ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
        );
        let http_client = StaticReplayClient::new(vec![]);
        let provider_config =
            ProviderConfig::no_configuration().with_http_client(http_client.clone());
        let imds_client = super::Client::builder()
            .retry_classifier(timeout_retrying_classifier)
            .configure(&provider_config)
            .operation_timeout(Duration::from_secs(1))
            .endpoint("http://240.0.0.0")
            .expect("valid uri")
            .build();

        // Measure how long the failing call takes from start to error.
        let started_at = SystemTime::now();
        let _failure = imds_client
            .get("/latest/metadata")
            .await
            .expect_err("240.0.0.0 will never resolve");
        let time_elapsed: Duration = started_at.elapsed().unwrap();

        // Lower bound: the call must not fail before the 1s operation timeout.
        assert!(
            time_elapsed > Duration::from_secs(1),
            "time_elapsed should be greater than 1s but was {:?}",
            time_elapsed
        );

        // Upper bound: the operation timeout must cut the call off before 2s.
        assert!(
            time_elapsed < Duration::from_secs(2),
            "time_elapsed should be less than 2s but was {:?}",
            time_elapsed
        );
    }

    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        env: HashMap<String, String>,