Unverified Commit f7080760 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Split `RuntimeError` and `RequestRejection` by protocol (#2517)

As outlined in the [Protocol Specific Errors] of the [Service Builder
Improvements RFC], `RuntimeError` should be split up into smaller,
protocol specific, errors which accurately model the failure cases of
each protocol.

The same goes for `RequestRejection`.

Closes #1703.

[Protocol Specific Errors]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md#protocol-specific-errors
[Service Builder Improvements RFC]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md
parent 3feb4bef
Loading
Loading
Loading
Loading
+16 −6
Original line number Diff line number Diff line
@@ -118,7 +118,7 @@ open class ServerCodegenVisitor(

        val baseModel = baselineTransform(context.model)
        val service = settings.getService(baseModel)
        val (protocol, generator) =
        val (protocolShape, protocolGeneratorFactory) =
            ServerProtocolLoader(
                codegenDecorator.protocols(
                    service.id,
@@ -126,7 +126,7 @@ open class ServerCodegenVisitor(
                ),
            )
                .protocolFor(context.model, service)
        protocolGeneratorFactory = generator
        this.protocolGeneratorFactory = protocolGeneratorFactory

        model = codegenDecorator.transformModel(service, baseModel)

@@ -145,7 +145,7 @@ open class ServerCodegenVisitor(
            serverSymbolProviders.symbolProvider,
            null,
            service,
            protocol,
            protocolShape,
            settings,
            serverSymbolProviders.unconstrainedShapeSymbolProvider,
            serverSymbolProviders.constrainedShapeSymbolProvider,
@@ -169,7 +169,7 @@ open class ServerCodegenVisitor(
            settings.codegenConfig,
            codegenContext.expectModuleDocProvider(),
        )
        protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
        protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
    }

    /**
@@ -315,7 +315,12 @@ open class ServerCodegenVisitor(
        writer: RustWriter,
    ) {
        if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) {
            val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator)
            val serverBuilderGenerator = ServerBuilderGenerator(
                codegenContext,
                shape,
                validationExceptionConversionGenerator,
                protocolGenerator.protocol,
            )
            serverBuilderGenerator.render(rustCrate, writer)

            if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
@@ -336,7 +341,12 @@ open class ServerCodegenVisitor(

        if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) {
            val serverBuilderGeneratorWithoutPublicConstrainedTypes =
                ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator)
                ServerBuilderGeneratorWithoutPublicConstrainedTypes(
                    codegenContext,
                    shape,
                    validationExceptionConversionGenerator,
                    protocolGenerator.protocol,
                )
            serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer)

            writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
+4 −11
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy

import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

@@ -15,17 +14,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
 * For a runtime type that is used in the client, or in both the client and the server, use [RuntimeType] directly.
 */
object ServerRuntimeType {
    fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)
    fun router(runtimeConfig: RuntimeConfig) =
        ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

    fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

    fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError")

    fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection")

    fun responseRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::ResponseRejection")

    fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")
    fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) =
        ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")

    fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig)
}
+4 −7
Original line number Diff line number Diff line
@@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
@@ -35,6 +34,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitI
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
@@ -67,11 +67,7 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
    override val shapeId: ShapeId =
        ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse)

    override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
        val codegenScope = arrayOf(
            "RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
            "From" to RuntimeType.From,
        )
    override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
        rustTemplate(
            """
            impl #{From}<ConstraintViolation> for #{RequestRejection} {
@@ -89,7 +85,8 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
                }
            }
            """,
            *codegenScope,
            "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
            "From" to RuntimeType.From,
        )
    }

+4 −7
Original line number Diff line number Diff line
@@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
@@ -34,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
@@ -66,11 +66,7 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
    }
    override val shapeId: ShapeId = SHAPE_ID

    override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
        val codegenScope = arrayOf(
            "RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
            "From" to RuntimeType.From,
        )
    override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
        rustTemplate(
            """
            impl #{From}<ConstraintViolation> for #{RequestRejection} {
@@ -87,7 +83,8 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
                }
            }
            """,
            *codegenScope,
            "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
            "From" to RuntimeType.From,
        )
    }

+5 −3
Original line number Diff line number Diff line
@@ -46,8 +46,8 @@ import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
@@ -92,6 +92,7 @@ class ServerBuilderGenerator(
    val codegenContext: ServerCodegenContext,
    private val shape: StructureShape,
    private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator,
    private val protocol: ServerProtocol,
) {
    companion object {
        /**
@@ -148,7 +149,7 @@ class ServerBuilderGenerator(
        ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)

    private val codegenScope = arrayOf(
        "RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
        "RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
        "Structure" to structureSymbol,
        "From" to RuntimeType.From,
        "TryFrom" to RuntimeType.TryFrom,
@@ -222,7 +223,8 @@ class ServerBuilderGenerator(
            """
            #{Converter:W}
            """,
            "Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(),
            "Converter" to
                customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(protocol),
        )
    }

Loading