Unverified Commit cbcbed95 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement ProvideErrorMetadata for service errors (#3189)

This PR implements the `ProvideErrorMetadata` trait for service errors
as a prerequisite for implementing
[RFC-39](https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0039_forward_compatible_errors.md).

Related SDK issue: https://github.com/awslabs/aws-sdk-rust/issues/780

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent f9241adb
Loading
Loading
Loading
Loading
+24 −0
Original line number Diff line number Diff line
@@ -113,3 +113,27 @@ message = "The `RequestId` trait has moved from the aws-http crate into aws-type
references = ["smithy-rs#3160"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Add `ProvideErrorMetadata` impl for service `Error` type."
references = ["aws-sdk-rust#780", "smithy-rs#3189"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Add `ProvideErrorMetadata` impl for service `Error` type."
references = ["aws-sdk-rust#780", "smithy-rs#3189"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"

[[aws-sdk-rust]]
message = "Remove deprecated error kind type aliases."
references = ["smithy-rs#3189"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Remove deprecated error kind type aliases."
references = ["smithy-rs#3189"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" }
author = "jdisanti"
+2 −0
Original line number Diff line number Diff line
@@ -103,9 +103,11 @@ class RequiredCustomizations : ClientCodegenDecorator {

                pub use #{DisplayErrorContext};
                pub use #{ProvideErrorMetadata};
                pub use #{ErrorMetadata};
                """,
                "DisplayErrorContext" to RuntimeType.smithyTypes(rc).resolve("error::display::DisplayErrorContext"),
                "ProvideErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ProvideErrorMetadata"),
                "ErrorMetadata" to RuntimeType.smithyTypes(rc).resolve("error::metadata::ErrorMetadata"),
                "R" to RuntimeType.smithyRuntimeApiClient(rc).resolve("client::orchestrator::HttpResponse"),
                "SdkError" to RuntimeType.sdkError(rc),
                // this can't use the auto-rexport because the builder generator is defined in codegen core
+33 −30
Original line number Diff line number Diff line
@@ -35,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizat
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

@@ -75,22 +74,6 @@ class OperationErrorGenerator(
            visibility = Visibility.PUBLIC,
        )

        // TODO(deprecated): Remove this temporary alias. This was added so that the compiler
        // points customers in the right direction when they are upgrading. Unfortunately there's no
        // way to provide better backwards compatibility on this change.
        val kindDeprecationMessage = "Operation `*Error/*ErrorKind` types were combined into a single `*Error` enum. " +
            "The `.kind` field on `*Error` no longer exists and isn't needed anymore (you can just match on the " +
            "error directly since it's an enum now)."
        writer.rust(
            """
            /// Do not use this.
            ///
            /// $kindDeprecationMessage
            ##[deprecated(note = ${kindDeprecationMessage.dq()})]
            pub type ${errorSymbol.name}Kind = ${errorSymbol.name};
            """,
        )

        writer.rust("/// Error type for the `${errorSymbol.name}` operation.")
        meta.render(writer)
        writer.rustBlock("enum ${errorSymbol.name}") {
@@ -108,7 +91,18 @@ class OperationErrorGenerator(
                unhandledError(runtimeConfig),
            )
        }
        writer.rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) {

        writer.renderImpl(errorSymbol, errors)
        writer.renderImplStdError(errorSymbol, errors)
        writer.renderImplDisplay(errorSymbol, errors)
        writer.renderImplProvideErrorKind(errorSymbol, errors)
        writer.renderImplProvideErrorMetadata(errorSymbol, errors)
        writer.renderImplCreateUnhandledError(errorSymbol)
        writer.writeCustomizations(customizations, ErrorSection.OperationErrorAdditionalTraitImpls(errorSymbol, errors))
    }

    private fun RustWriter.renderImplCreateUnhandledError(errorSymbol: Symbol) {
        rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) {
            rustBlockTemplate(
                """
                fn create_unhandled_error(
@@ -132,27 +126,32 @@ class OperationErrorGenerator(
                )
            }
        }
        writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) {
    }

    private fun RustWriter.renderImplDisplay(errorSymbol: Symbol, errors: List<StructureShape>) {
        rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) {
            rustBlock("fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result") {
                delegateToVariants(errors) {
                    writable { rust("_inner.fmt(f)") }
                }
            }
        }
    }

    private fun RustWriter.renderImplProvideErrorMetadata(errorSymbol: Symbol, errors: List<StructureShape>) {
        val errorMetadataTrait = RuntimeType.provideErrorMetadataTrait(runtimeConfig)
        writer.rustBlock("impl #T for ${errorSymbol.name}", errorMetadataTrait) {
        rustBlock("impl #T for ${errorSymbol.name}", errorMetadataTrait) {
            rustBlock("fn meta(&self) -> &#T", errorMetadata(runtimeConfig)) {
                delegateToVariants(errors) {
                    writable { rust("#T::meta(_inner)", errorMetadataTrait) }
                }
            }
        }
    }

        writer.writeCustomizations(customizations, ErrorSection.OperationErrorAdditionalTraitImpls(errorSymbol, errors))

    private fun RustWriter.renderImplProvideErrorKind(errorSymbol: Symbol, errors: List<StructureShape>) {
        val retryErrorKindT = RuntimeType.retryErrorKind(symbolProvider.config.runtimeConfig)
        writer.rustBlock(
        rustBlock(
            "impl #T for ${errorSymbol.name}",
            RuntimeType.provideErrorKind(symbolProvider.config.runtimeConfig),
        ) {
@@ -182,9 +181,11 @@ class OperationErrorGenerator(
                }
            }
        }
    }

        writer.rustBlock("impl ${errorSymbol.name}") {
            writer.rustTemplate(
    private fun RustWriter.renderImpl(errorSymbol: Symbol, errors: List<StructureShape>) {
        rustBlock("impl ${errorSymbol.name}") {
            rustTemplate(
                """
                /// Creates the `${errorSymbol.name}::Unhandled` variant from any error type.
                pub fn unhandled(err: impl #{Into}<#{Box}<dyn #{StdError} + #{Send} + #{Sync} + 'static>>) -> Self {
@@ -201,13 +202,13 @@ class OperationErrorGenerator(
                "StdError" to RuntimeType.StdError,
                "Unhandled" to unhandledError(runtimeConfig),
            )
            writer.docs(
            docs(
                """
                Returns error metadata, which includes the error code, message,
                request ID, and potentially additional information.
                """,
            )
            writer.rustBlock("pub fn meta(&self) -> &#T", errorMetadata) {
            rustBlock("pub fn meta(&self) -> &#T", errorMetadata) {
                rust("use #T;", RuntimeType.provideErrorMetadataTrait(runtimeConfig))
                rustBlock("match self") {
                    errors.forEach { error ->
@@ -220,14 +221,16 @@ class OperationErrorGenerator(
            errors.forEach { error ->
                val errorVariantSymbol = symbolProvider.toSymbol(error)
                val fnName = errorVariantSymbol.name.toSnakeCase()
                writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}::${errorVariantSymbol.name}`.")
                writer.rustBlock("pub fn is_$fnName(&self) -> bool") {
                rust("/// Returns `true` if the error kind is `${errorSymbol.name}::${errorVariantSymbol.name}`.")
                rustBlock("pub fn is_$fnName(&self) -> bool") {
                    rust("matches!(self, Self::${errorVariantSymbol.name}(_))")
                }
            }
        }
    }

        writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) {
    private fun RustWriter.renderImplStdError(errorSymbol: Symbol, errors: List<StructureShape>) {
        rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) {
            rustBlockTemplate(
                "fn source(&self) -> #{Option}<&(dyn #{StdError} + 'static)>",
                *preludeScope,
+29 −0
Original line number Diff line number Diff line
@@ -21,9 +21,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.unhandledError
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
@@ -67,6 +69,7 @@ class ServiceErrorGenerator(
            renderDefinition()
            renderImplDisplay()
            renderImplFromBuildError()
            renderImplProvideErrorMetadata()
            // Every operation error can be converted into service::Error
            operations.forEach { operationShape ->
                // operation errors
@@ -177,6 +180,32 @@ class ServiceErrorGenerator(
        }
    }

    private fun RustWriter.renderImplProvideErrorMetadata() {
        rustTemplate(
            """
            impl #{ProvideErrorMetadata} for Error {
                fn meta(&self) -> &#{ErrorMetadata} {
                    match self {
                        #{matchers}
                        Self::Unhandled(inner) => inner.meta(),
                    }
                }
            }
            """,
            *preludeScope,
            "ErrorMetadata" to RuntimeType.smithyTypes(codegenContext.runtimeConfig)
                .resolve("error::metadata::ErrorMetadata"),
            "ProvideErrorMetadata" to RuntimeType.smithyTypes(codegenContext.runtimeConfig)
                .resolve("error::metadata::ProvideErrorMetadata"),
            "matchers" to writable {
                allErrors.forEach { errorShape ->
                    val errSymbol = symbolProvider.toSymbol(errorShape)
                    rust("Self::${errSymbol.name}(inner) => inner.meta(),")
                }
            },
        )
    }

    private fun RustWriter.renderDefinition() {
        rust("/// All possible error types for this service.")
        RustMetadata(
+26 −0
Original line number Diff line number Diff line
@@ -84,4 +84,30 @@ internal class ServiceErrorGeneratorTest {
            }
        }
    }

    @Test
    fun `provides error metadata`() {
        clientIntegrationTest(model) { _, rustCrate ->
            rustCrate.moduleFor(model.lookup<StructureShape>("com.example#CanYouRepeatThat")) {
                unitTest(
                    name = "generates_combined_error_enums",
                    test = """
                        use crate::Error;
                        use crate::error::{ErrorMetadata, ProvideErrorMetadata};
                        use crate::operation::say_hello::SayHelloError;
                        use crate::types::error::*;

                        // Unhandled variants properly delegate source.
                        let error = Error::from(SayHelloError::SorryBusy(
                            SorryBusy::builder()
                                .meta(ErrorMetadata::builder().code("some code").message("some message").build())
                                .build()
                        ));
                        assert_eq!("some code", error.code().expect("code field"));
                        assert_eq!("some message", error.message().expect("message field"));
                    """,
                )
            }
        }
    }
}