Unverified Commit c7fba4d0 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add support for ErrorKind and RetryKind to all error models (#180)



* Modeling Structured Retries

* Improvements to docs

Co-authored-by: default avatarDavid Barsky <dbarsky@amazon.com>

* Improve RetryKind documentation

* Remove println, add .

* Rename to classify

* Add RetryPolicy test

* Fix clippy errors

* Remove unused paramaterization of AwsRetryPolicy

* Add missing retry module

* Cut down on generated code

* Fix tests

* Rename error_kind to retryable_error_kind

* Fix tests

Co-authored-by: default avatarDavid Barsky <dbarsky@amazon.com>
parent bbe1f0a9
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -29,11 +29,17 @@ impl EnvironmentVariableCredentialsProvider {
    }
}

impl Default for EnvironmentVariableCredentialsProvider {
    fn default() -> Self {
        Self::new()
    }
}

fn var(key: &str) -> Result<String, VarError> {
    std::env::var(key)
}

const ENV_PROVIDER: &'static str = "EnvironmentVariable";
const ENV_PROVIDER: &str = "EnvironmentVariable";

impl ProvideCredentials for EnvironmentVariableCredentialsProvider {
    fn credentials(&self) -> Result<Credentials, CredentialsError> {
+4 −1
Original line number Diff line number Diff line
@@ -3,12 +3,15 @@ name = "aws-http"
version = "0.1.0"
authors = ["Russell Cohen <rcoh@amazon.com>"]
edition = "2018"
description = "HTTP specific AWS SDK behaviors"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
aws-types = { path = "../aws-types" }
smithy-http = { path = "../../../rust-runtime/smithy-http" }
thiserror = "1"
smithy-types = { path = "../../../rust-runtime/smithy-types" }

http = "0.2.3"
thiserror = "1"
lazy_static = "1"
+209 −0
Original line number Diff line number Diff line
pub mod user_agent;
use smithy_http::retry::ClassifyResponse;
use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
use std::time::Duration;

/// A retry policy that models AWS error codes as outlined in the SEP
///
/// In order of priority:
/// 1. The `x-amz-retry-after` header is checked
/// 2. The modeled error retry mode is checked
/// 3. The code is checked against a predetermined list of throttling errors & transient error codes
/// 4. The status code is checked against a predetermined list of status codes
#[non_exhaustive]
pub struct AwsErrorRetryPolicy;

const TRANSIENT_ERROR_STATUS_CODES: [u16; 2] = [400, 408];
const THROTTLING_ERRORS: &[&str] = &[
    "Throttling",
    "ThrottlingException",
    "ThrottledException",
    "RequestThrottledException",
    "TooManyRequestsException",
    "ProvisionedThroughputExceededException",
    "TransactionInProgressException",
    "RequestLimitExceeded",
    "BandwidthLimitExceeded",
    "LimitExceededException",
    "RequestThrottled",
    "SlowDown",
    "PriorRequestNotComplete",
    "EC2ThrottledException",
];
const TRANSIENT_ERRORS: &[&str] = &["RequestTimeout", "RequestTimeoutException"];

impl AwsErrorRetryPolicy {
    /// Create an `AwsErrorRetryPolicy` with the default set of known error & status codes
    pub fn new() -> Self {
        AwsErrorRetryPolicy
    }
}

impl Default for AwsErrorRetryPolicy {
    fn default() -> Self {
        Self::new()
    }
}

impl ClassifyResponse for AwsErrorRetryPolicy {
    fn classify<E, B>(&self, err: E, response: &http::Response<B>) -> RetryKind
        where
            E: ProvideErrorKind,
    {
        if let Some(retry_after_delay) = response
            .headers()
            .get("x-amz-retry-after")
            .and_then(|header| header.to_str().ok())
            .and_then(|header| header.parse::<u64>().ok())
        {
            return RetryKind::Explicit(Duration::from_millis(retry_after_delay));
        }
        if let Some(kind) = err.retryable_error_kind() {
            return RetryKind::Error(kind);
        };
        if let Some(code) = err.code() {
            if THROTTLING_ERRORS.contains(&code) {
                return RetryKind::Error(ErrorKind::ThrottlingError);
            }
            if TRANSIENT_ERRORS.contains(&code) {
                return RetryKind::Error(ErrorKind::TransientError);
            }
        };
        if TRANSIENT_ERROR_STATUS_CODES
            .contains(&response.status().as_u16())
        {
            return RetryKind::Error(ErrorKind::TransientError);
        };
        // TODO: is IDPCommunicationError modeled yet?
        RetryKind::NotRetryable
    }
}

#[cfg(test)]
mod test {
    use crate::AwsErrorRetryPolicy;
    use smithy_http::retry::ClassifyResponse;
    use smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind};
    use std::time::Duration;

    struct UnmodeledError;

    struct CodedError {
        code: &'static str,
    }

    impl ProvideErrorKind for UnmodeledError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            None
        }

        fn code(&self) -> Option<&str> {
            None
        }
    }

    impl ProvideErrorKind for CodedError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            None
        }

        fn code(&self) -> Option<&str> {
            Some(self.code)
        }
    }

    #[test]
    fn not_an_error() {
        let policy = AwsErrorRetryPolicy::new();
        let test_response = http::Response::new("OK");
        assert_eq!(
            policy.classify(UnmodeledError, &test_response),
            RetryKind::NotRetryable
        );
    }

    #[test]
    fn classify_by_response_status() {
        let policy = AwsErrorRetryPolicy::new();
        let test_resp = http::Response::builder()
            .status(408)
            .body("error!")
            .unwrap();
        assert_eq!(
            policy.classify(UnmodeledError, &test_resp),
            RetryKind::Error(ErrorKind::TransientError)
        );
    }

    #[test]
    fn classify_by_error_code() {
        let test_response = http::Response::new("OK");
        let policy = AwsErrorRetryPolicy::new();

        assert_eq!(
            policy.classify(CodedError { code: "Throttling" }, &test_response),
            RetryKind::Error(ErrorKind::ThrottlingError)
        );

        assert_eq!(
            policy.classify(
                CodedError {
                    code: "RequestTimeout"
                },
                &test_response,
            ),
            RetryKind::Error(ErrorKind::TransientError)
        )
    }

    #[test]
    fn classify_generic() {
        let err = smithy_types::Error {
            code: Some("SlowDown".to_string()),
            message: None,
            request_id: None,
        };
        let test_response = http::Response::new("OK");
        let policy = AwsErrorRetryPolicy::new();
        assert_eq!(
            policy.classify(err, &test_response),
            RetryKind::Error(ErrorKind::ThrottlingError)
        );
    }

    #[test]
    fn classify_by_error_kind() {
        struct ModeledRetries;
        let test_response = http::Response::new("OK");
        impl ProvideErrorKind for ModeledRetries {
            fn retryable_error_kind(&self) -> Option<ErrorKind> {
                Some(ErrorKind::ClientError)
            }

            fn code(&self) -> Option<&str> {
                // code should not be called when `error_kind` is provided
                unimplemented!()
            }
        }

        let policy = AwsErrorRetryPolicy::new();

        assert_eq!(
            policy.classify(ModeledRetries, &test_response),
            RetryKind::Error(ErrorKind::ClientError)
        );
    }

    #[test]
    fn test_retry_after_header() {
        let policy = AwsErrorRetryPolicy::new();
        let test_response = http::Response::builder()
            .header("x-amz-retry-after", "5000")
            .body("retry later")
            .unwrap();

        assert_eq!(
            policy.classify(UnmodeledError, &test_response),
            RetryKind::Explicit(Duration::from_millis(5000))
        );
    }
}
+3 −0
Original line number Diff line number Diff line
@@ -53,6 +53,9 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
    // TODO: refactor to be RuntimeTypeProvider a la Symbol provider that packages the `RuntimeConfig` state.
    companion object {

        fun errorKind(runtimeConfig: RuntimeConfig) = RuntimeType("ErrorKind", dependency = CargoDependency.SmithyTypes(runtimeConfig), namespace = "${runtimeConfig.cratePrefix}_types::retry")
        fun provideErrorKind(runtimeConfig: RuntimeConfig) = RuntimeType("ProvideErrorKind", dependency = CargoDependency.SmithyTypes(runtimeConfig), namespace = "${runtimeConfig.cratePrefix}_types::retry")

        // val Blob = RuntimeType("Blob", RustDependency.IO_CORE, "blob")
        val From = RuntimeType("From", dependency = null, namespace = "std::convert")
        val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert")
+30 −2
Original line number Diff line number Diff line
@@ -10,6 +10,8 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.RetryableTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.Derives
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
@@ -78,6 +80,32 @@ class CombinedErrorGenerator(
            }
        }

        val errorKindT = RuntimeType.errorKind(symbolProvider.config().runtimeConfig)
        writer.rustBlock(
            "impl #T for ${symbol.name}",
            RuntimeType.provideErrorKind(symbolProvider.config().runtimeConfig)
        ) {
            rustBlock("fn code(&self) -> Option<&str>") {
                rust("${symbol.name}::code(self)")
            }

            rustBlock("fn retryable_error_kind(&self) -> Option<#T>", errorKindT) {
                delegateToVariants {
                    when (it) {
                        is VariantMatch.Modeled -> writable {
                            if (it.shape.hasTrait(RetryableTrait::class.java)) {
                                rust("Some(_inner.retryable_error_kind())")
                            } else {
                                rust("None")
                            }
                        }
                        is VariantMatch.Generic -> writable { rust("_inner.retryable_error_kind()") }
                        is VariantMatch.Unhandled -> writable { rust("None") }
                    }
                }
            }
        }

        writer.rustBlock("impl ${symbol.name}") {
            writer.rustBlock("pub fn unhandled<E: Into<Box<dyn #T>>>(err: E) -> Self", RuntimeType.StdError) {
                write("${symbol.name}::Unhandled(err.into())")
@@ -122,7 +150,7 @@ class CombinedErrorGenerator(
    sealed class VariantMatch(name: String) : Section(name) {
        object Unhandled : VariantMatch("Unhandled")
        object Generic : VariantMatch("Generic")
        data class Modeled(val symbol: Symbol) : VariantMatch("Modeled")
        data class Modeled(val symbol: Symbol, val shape: Shape) : VariantMatch("Modeled")
    }

    /**
@@ -154,7 +182,7 @@ class CombinedErrorGenerator(
            errors.forEach {
                val errorSymbol = symbolProvider.toSymbol(it)
                rust("""${symbol.name}::${errorSymbol.name}(_inner) => """)
                handler(VariantMatch.Modeled(errorSymbol))(this)
                handler(VariantMatch.Modeled(errorSymbol, it))(this)
                write(",")
            }
            val genericHandler = handler(VariantMatch.Generic)
Loading