Unverified Commit afde25bd authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Generate simpler operation error types for the server (#804)

Currently, operation error types are designed to be used by clients.
Clients need to cope with a large scale of possible errors, including
unmodeled errors and errors coming from different layers of the stack.
On the other hand, all operation server errors should be related to
business logic; we want all of them to be correctly modeled and have
compilation fail if any of them is not properly handled.

This commit revamps error generation by the `rust-server-codegen`
plugin, generating a simple Rust enum type for each fallible operation,
holding one variant per modeled error.

Closes #780.
parent 610ed2f2
Loading
Loading
Loading
Loading
+112 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.documentShape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
 * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants,
 * but we must still combine those variants into an enum covering all possible errors for a given operation.
 */
class ServerCombinedErrorGenerator(
    private val model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val operation: OperationShape
) {
    private val operationIndex = OperationIndex.of(model)

    fun render(writer: RustWriter) {
        val errors = operationIndex.getErrors(operation)
        val operationSymbol = symbolProvider.toSymbol(operation)
        val symbol = operation.errorSymbol(symbolProvider)
        val meta = RustMetadata(
            derives = Attribute.Derives(setOf(RuntimeType.Debug)),
            public = true
        )

        writer.rust("/// Error type for the `${operationSymbol.name}` operation.")
        writer.rust("/// Each variant represents an error that can occur for the `${operationSymbol.name}` operation.")
        meta.render(writer)
        writer.rustBlock("enum ${symbol.name}") {
            errors.forEach { errorVariant ->
                documentShape(errorVariant, model)
                val errorVariantSymbol = symbolProvider.toSymbol(errorVariant)
                write("${errorVariantSymbol.name}(#T),", errorVariantSymbol)
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.stdfmt.member("Display")) {
            rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
                delegateToVariants {
                    rust("_inner.fmt(f)")
                }
            }
        }

        writer.rustBlock("impl ${symbol.name}") {
            errors.forEach { error ->
                val errorSymbol = symbolProvider.toSymbol(error)
                val fnName = errorSymbol.name.toSnakeCase()
                writer.rust("/// Returns `true` if the error kind is `${symbol.name}::${errorSymbol.name}`.")
                writer.rustBlock("pub fn is_$fnName(&self) -> bool") {
                    rust("matches!(&self, ${symbol.name}::${errorSymbol.name}(_))")
                }
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) {
            rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
                delegateToVariants {
                    rust("Some(_inner)");
                }
            }
        }
    }

    /**
     * Generates code to delegate behavior to the variants, for example:
     *
     * ```rust
     *  match &self {
     *      GreetingWithErrorsError::InvalidGreeting(_inner) => inner.fmt(f),
     *      GreetingWithErrorsError::ComplexError(_inner) => inner.fmt(f),
     *      GreetingWithErrorsError::FooError(_inner) => inner.fmt(f),
     *      GreetingWithErrorsError::Unhandled(_inner) => _inner.fmt(f),
     *  }
     *  ```
     *
     * A [writable] is passed containing the content to be written for each variant.
     *
     *  The field will always be bound as `_inner`.
     */
    private fun RustWriter.delegateToVariants(
        writable: Writable
    ) {
        val errors = operationIndex.getErrors(operation)
        val symbol = operation.errorSymbol(symbolProvider)
        rustBlock("match &self") {
            errors.forEach {
                val errorSymbol = symbolProvider.toSymbol(it)
                rust("""${symbol.name}::${errorSymbol.name}(_inner) => """)
                writable(this)
                write(",")
            }
        }
    }
}
+7 −8
Original line number Diff line number Diff line
@@ -11,8 +11,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.error.CombinedErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.TopLevelErrorGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport

@@ -34,7 +33,6 @@ class ServerServiceGenerator(
    /**
     * Render Service Specific code. Code will end up in different files via [useShapeWriter]. See `SymbolVisitor.kt`
     * which assigns a symbol location to each shape.
     *
     */
    fun render() {
        val operations = index.getContainedOperations(context.serviceShape).sortedBy { it.id }
@@ -48,12 +46,13 @@ class ServerServiceGenerator(
                ServerProtocolTestGenerator(context, protocolSupport, operation, operationWriter)
                    .render()
            }

            if (operation.errors.isNotEmpty()) {
                rustCrate.withModule(RustModule.Error) { writer ->
                CombinedErrorGenerator(context.model, context.symbolProvider, operation)
                    ServerCombinedErrorGenerator(context.model, context.symbolProvider, operation)
                        .render(writer)
                }
            }

        TopLevelErrorGenerator(context, operations).render(rustCrate)
        }
    }
}
+0 −5
Original line number Diff line number Diff line
@@ -7,8 +7,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
@@ -33,13 +31,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.HttpServerTra
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
+12 −41
Original line number Diff line number Diff line
@@ -37,7 +37,6 @@ import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTr
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolTraitImplGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.transformers.errorMessageMember
@@ -121,7 +120,6 @@ private class ServerHttpProtocolImplGenerator(
        "PercentEncoding" to CargoDependency("percent-encoding", CratesIo("2.1.0")).asType(),
        "SdkBody" to RuntimeType.sdkBody(runtimeConfig)
    )
    private val httpBoundProtocolGenerator = HttpBoundProtocolTraitImplGenerator(codegenContext, protocol)

    override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
        val inputSymbol = symbolProvider.toSymbol(operationShape.inputShape(model))
@@ -155,26 +153,6 @@ private class ServerHttpProtocolImplGenerator(
        operationShape: OperationShape
    ) {
        val errorSymbol = operationShape.errorSymbol(symbolProvider)
        val successCode = httpBindingResolver.httpTrait(operationShape).code
        /* Implement `ParseStrictResponse` from client codegen to be used inside tests */
        rustTemplate(
            """
            impl #{ParseStrictResponse} for $operationName {
                type Output = std::result::Result<#{O}, #{E}>;
                fn parse(&self, response: &#{http}::Response<#{Bytes}>) -> Self::Output {
                     if !response.status().is_success() && response.status().as_u16() != $successCode {
                        #{parse_error}(response)
                     } else {
                        #{parse_response}(response)
                     }
                }
            }""",
            *codegenScope,
            "O" to outputSymbol,
            "E" to operationShape.errorSymbol(symbolProvider),
            "parse_error" to httpBoundProtocolGenerator.parseError(operationShape),
            "parse_response" to httpBoundProtocolGenerator.parseResponse(operationShape)
        )
        /* Implement `ParseHttpRequest` for non streaming types. This is done by only implementing `parse_loaded` */
        rustTemplate(
            """
@@ -206,11 +184,12 @@ private class ServerHttpProtocolImplGenerator(
            "serialize_response" to serverSerializeResponse(operationShape)
        )
        /* Implement `SerializeHttpError` for non streaming types. This is done by only implementing `serialize` */
        if (operationShape.errors.isNotEmpty()) {
            rustTemplate(
                """
            impl #{SerializeHttpError} for $operationName {
                type Output = std::result::Result<#{http}::Response<#{Bytes}>, #{Error}>;
                type Struct = #{E}Kind;
                type Struct = #{E};
                fn serialize(&self, error: &Self::Struct) -> Self::Output {
                    #{serialize_error}(error)
                }
@@ -220,6 +199,7 @@ private class ServerHttpProtocolImplGenerator(
                "serialize_error" to serverSerializeError(operationShape)
            )
        }
    }

    /*
     * TODO: implement streaming traits
@@ -281,7 +261,7 @@ private class ServerHttpProtocolImplGenerator(
        return RuntimeType.forInlineFun(fnName, operationSerModule) {
            Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
            it.rustBlockTemplate(
                "pub fn $fnName(error: &#{E}Kind) -> std::result::Result<#{http}::Response<#{Bytes}>, #{Error}>",
                "pub fn $fnName(error: &#{E}) -> std::result::Result<#{http}::Response<#{Bytes}>, #{Error}>",
                *codegenScope,
                "E" to errorSymbol
            ) {
@@ -309,10 +289,10 @@ private class ServerHttpProtocolImplGenerator(
                val variantSymbol = symbolProvider.toSymbol(variantShape)
                val data = safeName("var")
                val serializerSymbol = structuredDataSerializer.serverErrorSerializer(it)
                rustBlock("#TKind::${variantSymbol.name}($data) =>", errorSymbol) {
                rustBlock("#T::${variantSymbol.name}($data) =>", errorSymbol) {
                    rust(
                        """
                        #T(&$data)?;
                        #T($data)?;
                        object.key(${"code".dq()}).string(${httpBindingResolver.errorCode(variantShape).dq()});
                        """.trimIndent(),
                        serializerSymbol
@@ -341,15 +321,6 @@ private class ServerHttpProtocolImplGenerator(
                    rust("response = response.status($status);")
                }
            }
            rust(
                """
                #TKind::Unhandled(_) => {
                    object.key(${"code".dq()}).string(${"Unhandled".dq()});
                    response = response.status(500);
                }
                """.trimIndent(),
                errorSymbol
            )
        }
        rust("object.finish();")
        rustTemplate(