Unverified Commit 0cba3d89 authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

Update retry classifiers to act on InterceptorContext instead of OrchestratorError (#2737)

## Description
<!--- Describe your changes in detail -->
- update retry classifiers to act directly on the `InterceptorContext`
- add codegen method for inserting retry classifiers into the config bag
- update `InterceptorContext` to return options instead of panicking

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Tests were updated as needed.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 40612a93
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -80,7 +80,12 @@ mod tests {
    use http::HeaderValue;

    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a HeaderValue {
        context.request().headers().get(header_name).unwrap()
        context
            .request()
            .expect("request is set")
            .headers()
            .get(header_name)
            .unwrap()
    }

    #[test]
+1 −1
Original line number Diff line number Diff line
@@ -159,7 +159,7 @@ mod tests {
        RecursionDetectionInterceptor { env }
            .modify_before_signing(&mut ctx, &mut config)
            .expect("interceptor must succeed");
        let mutated_request = context.request();
        let mutated_request = context.request().expect("request is set");
        for name in mutated_request.headers().keys() {
            assert_eq!(
                mutated_request.headers().get_all(name).iter().count(),
+1 −0
Original line number Diff line number Diff line
@@ -168,6 +168,7 @@ mod tests {
    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
        context
            .request()
            .expect("request is set")
            .headers()
            .get(header_name)
            .unwrap()
+93 −41
Original line number Diff line number Diff line
@@ -4,10 +4,13 @@
 */

use aws_smithy_http::http::HttpHeaders;
use aws_smithy_http::result::SdkError;
use aws_smithy_runtime_api::client::retries::RetryReason;
use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_smithy_types::retry::ErrorKind;
use std::error::Error as StdError;
use std::marker::PhantomData;

/// AWS error codes that represent throttling errors.
pub const THROTTLING_ERRORS: &[&str] = &[
@@ -31,16 +34,31 @@ pub const THROTTLING_ERRORS: &[&str] = &[
pub const TRANSIENT_ERRORS: &[&str] = &["RequestTimeout", "RequestTimeoutException"];

/// A retry classifier for determining if the response sent by an AWS service requires a retry.
#[derive(Debug)]
pub struct AwsErrorCodeClassifier;

impl AwsErrorCodeClassifier {
    /// Classify an error code to check if represents a retryable error. The codes of retryable
    /// errors are defined [here](THROTTLING_ERRORS) and [here](TRANSIENT_ERRORS).
    pub fn classify_error<E: ProvideErrorMetadata, R>(
        &self,
        error: &SdkError<E, R>,
    ) -> Option<RetryReason> {
#[derive(Debug, Default)]
pub struct AwsErrorCodeClassifier<E> {
    _inner: PhantomData<E>,
}

impl<E> AwsErrorCodeClassifier<E> {
    /// Create a new AwsErrorCodeClassifier
    pub fn new() -> Self {
        Self {
            _inner: PhantomData,
        }
    }
}

impl<E> ClassifyRetry for AwsErrorCodeClassifier<E>
where
    E: StdError + ProvideErrorMetadata + Send + Sync + 'static,
{
    fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
        let error = ctx
            .output_or_error()?
            .err()
            .and_then(OrchestratorError::as_operation_error)?
            .downcast_ref::<E>()?;

        if let Some(error_code) = error.code() {
            if THROTTLING_ERRORS.contains(&error_code) {
                return Some(RetryReason::Error(ErrorKind::ThrottlingError));
@@ -51,18 +69,27 @@ impl AwsErrorCodeClassifier {

        None
    }

    fn name(&self) -> &'static str {
        "AWS Error Code"
    }
}

/// A retry classifier that checks for `x-amz-retry-after` headers. If one is found, a
/// [`RetryReason::Explicit`] is returned containing the duration to wait before retrying.
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct AmzRetryAfterHeaderClassifier;

impl AmzRetryAfterHeaderClassifier {
    /// Classify an AWS responses error code to determine how (and if) it should be retried.
    pub fn classify_error<E>(&self, error: &SdkError<E>) -> Option<RetryReason> {
        error
            .raw_response()
    /// Create a new `AmzRetryAfterHeaderClassifier`.
    pub fn new() -> Self {
        Self
    }
}

impl ClassifyRetry for AmzRetryAfterHeaderClassifier {
    fn classify_retry(&self, ctx: &InterceptorContext) -> Option<RetryReason> {
        ctx.response()
            .and_then(|res| res.http_headers().get("x-amz-retry-after"))
            .and_then(|header| header.to_str().ok())
            .and_then(|header| header.parse::<u64>().ok())
@@ -70,18 +97,23 @@ impl AmzRetryAfterHeaderClassifier {
                RetryReason::Explicit(std::time::Duration::from_millis(retry_after_delay))
            })
    }

    fn name(&self) -> &'static str {
        "'Retry After' Header"
    }
}

#[cfg(test)]
mod test {
    use super::{AmzRetryAfterHeaderClassifier, AwsErrorCodeClassifier};
    use aws_smithy_http::body::SdkBody;
    use aws_smithy_http::operation;
    use aws_smithy_http::result::SdkError;
    use aws_smithy_runtime_api::client::retries::RetryReason;
    use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryReason};
    use aws_smithy_types::error::metadata::ProvideErrorMetadata;
    use aws_smithy_types::error::ErrorMetadata;
    use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
    use aws_smithy_types::type_erasure::{TypeErasedBox, TypeErasedError};
    use std::fmt;
    use std::time::Duration;

@@ -96,6 +128,17 @@ mod test {

    impl std::error::Error for UnmodeledError {}

    impl ProvideErrorKind for UnmodeledError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            None
        }

        fn code(&self) -> Option<&str> {
            None
        }
    }

    #[derive(Debug)]
    struct CodedError {
        metadata: ErrorMetadata,
    }
@@ -108,16 +151,14 @@ mod test {
        }
    }

    impl ProvideErrorKind for UnmodeledError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            None
        }

        fn code(&self) -> Option<&str> {
            None
    impl fmt::Display for CodedError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "Coded Error")
        }
    }

    impl std::error::Error for CodedError {}

    impl ProvideErrorMetadata for CodedError {
        fn meta(&self) -> &ErrorMetadata {
            &self.metadata
@@ -126,31 +167,39 @@ mod test {

    #[test]
    fn classify_by_error_code() {
        let policy = AwsErrorCodeClassifier;
        let res = http::Response::new("OK");
        let err = SdkError::service_error(CodedError::new("Throttling"), res);
        let policy = AwsErrorCodeClassifier::<CodedError>::new();
        let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
        ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(
            CodedError::new("Throttling"),
        ))));

        assert_eq!(
            policy.classify_error(&err),
            policy.classify_retry(&ctx),
            Some(RetryReason::Error(ErrorKind::ThrottlingError))
        );

        let res = http::Response::new("OK");
        let err = SdkError::service_error(CodedError::new("RequestTimeout"), res);
        let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
        ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(
            CodedError::new("RequestTimeout"),
        ))));
        assert_eq!(
            policy.classify_error(&err),
            policy.classify_retry(&ctx),
            Some(RetryReason::Error(ErrorKind::TransientError))
        )
    }

    #[test]
    fn classify_generic() {
        let policy = AwsErrorCodeClassifier;
        let res = http::Response::new("OK");
        let policy = AwsErrorCodeClassifier::<ErrorMetadata>::new();
        let err = aws_smithy_types::Error::builder().code("SlowDown").build();
        let err = SdkError::service_error(err, res);
        let test_response = http::Response::new("OK").map(SdkBody::from);

        let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
        ctx.set_response(test_response);
        ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(err))));

        assert_eq!(
            policy.classify_error(&err),
            policy.classify_retry(&ctx),
            Some(RetryReason::Error(ErrorKind::ThrottlingError))
        );
    }
@@ -163,11 +212,14 @@ mod test {
            .body("retry later")
            .unwrap()
            .map(SdkBody::from);
        let res = operation::Response::new(res);
        let err = SdkError::service_error(UnmodeledError, res);
        let mut ctx = InterceptorContext::new(TypeErasedBox::new("doesntmatter"));
        ctx.set_response(res);
        ctx.set_output_or_error(Err(OrchestratorError::operation(TypeErasedError::new(
            UnmodeledError,
        ))));

        assert_eq!(
            policy.classify_error(&err),
            policy.classify_retry(&ctx),
            Some(RetryReason::Explicit(Duration::from_millis(5000))),
        );
    }
+1 −0
Original line number Diff line number Diff line
@@ -117,6 +117,7 @@ mod tests {
    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
        context
            .request()
            .expect("request is set")
            .headers()
            .get(header_name)
            .unwrap()
Loading