Unverified Commit 4c30f00f authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Refactor determining server error type when deserializing an `@httpPayload` (#3752)

Determining the error type when deserializing an `@httpPayload` is a
protocol-specific concern, and as such should not live in
`ServerHttpBoundProtocolGenerator`, which should remain
protocol-agnostic. This commits makes that determination part of the
`ServerProtocol` interface.

As a drive-by improvement, the companion object in
`ServerHttpBoundProtocolGenerator` has also been removed, since its
members have been unused for a long time.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent dc66ae4b
Loading
Loading
Loading
Loading
+3 −5
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.http

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
@@ -20,12 +19,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

class ServerRequestBindingGenerator(
    protocol: Protocol,
    val protocol: ServerProtocol,
    codegenContext: ServerCodegenContext,
    operationShape: OperationShape,
    additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
@@ -50,12 +49,11 @@ class ServerRequestBindingGenerator(

    fun generateDeserializePayloadFn(
        binding: HttpBindingDescriptor,
        errorSymbol: Symbol,
        structuredHandler: RustWriter.(String) -> Unit,
    ): RuntimeType =
        httpBindingGenerator.generateDeserializePayloadFn(
            binding,
            errorSymbol,
            protocol.deserializePayloadErrorType(binding).toSymbol(),
            structuredHandler,
            HttpMessageType.REQUEST,
        )
+64 −2
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
@@ -17,7 +18,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
@@ -70,8 +73,8 @@ interface ServerProtocol : Protocol {
    fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType

    /**
     * In some protocols, such as restJson1,
     * when there is no modeled body input, content type must not be set and the body must be empty.
     * In some protocols, such as `restJson1` and `rpcv2Cbor`,
     * when there is no modeled body input, `content-type` must not be set and the body must be empty.
     * Returns a boolean indicating whether to perform this check.
     */
    fun serverContentTypeCheckNoModeledInput(): Boolean = false
@@ -90,6 +93,19 @@ interface ServerProtocol : Protocol {
    fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
        ServerCargoDependency.smithyHttpServer(runtimeConfig)
            .toType().resolve("protocol::$protocolModulePath::runtime_error::RuntimeError")

    /**
     * The function that deserializes a payload-bound shape takes as input a byte slab and returns a `Result` holding
     * the deserialized shape if successful. What error type should we use in case of failure?
     *
     * The shape could be payload-bound either because of the `@httpPayload` trait, or because it's part of an event
     * stream.
     *
     * Note that despite the trait (https://smithy.io/2.0/spec/http-bindings.html#httppayload-trait) being able to
     * target any structure member shape, AWS Protocols only support binding the following shape types to the payload
     * (and Smithy does indeed enforce this at model build-time): string, blob, structure, union, and document
     */
    fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType
}

fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse {
@@ -185,6 +201,18 @@ class ServerAwsJsonProtocol(
    override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
        ServerCargoDependency.smithyHttpServer(runtimeConfig)
            .toType().resolve("protocol::aws_json::runtime_error::RuntimeError")

    /*
     * Note that despite the AWS JSON 1.x protocols not supporting the `@httpPayload` trait, event streams are bound
     * to the payload.
     */
    override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
        deserializePayloadErrorType(
            codegenContext,
            binding,
            requestRejection(runtimeConfig),
            RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
        )
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
@@ -227,6 +255,14 @@ class ServerRestJsonProtocol(
    override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

    override fun serverContentTypeCheckNoModeledInput() = true

    override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
        deserializePayloadErrorType(
            codegenContext,
            binding,
            requestRejection(runtimeConfig),
            RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
        )
}

class ServerRestXmlProtocol(
@@ -252,6 +288,32 @@ class ServerRestXmlProtocol(
    override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

    override fun serverContentTypeCheckNoModeledInput() = true

    override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
        deserializePayloadErrorType(
            codegenContext,
            binding,
            requestRejection(runtimeConfig),
            RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"),
        )
}

/** Just a common function to keep things DRY. **/
fun deserializePayloadErrorType(
    codegenContext: CodegenContext,
    binding: HttpBindingDescriptor,
    requestRejection: RuntimeType,
    protocolSerializationFormatError: RuntimeType,
): RuntimeType {
    check(binding.location == HttpLocation.PAYLOAD)

    if (codegenContext.model.expectShape(binding.member.target) is StringShape) {
        // The only way deserializing a string can fail is if the HTTP body does not contain valid UTF-8.
        // TODO(https://github.com/smithy-lang/smithy-rs/issues/3750): we're returning an incorrect `RequestRejection` variant here.
        return requestRejection
    }

    return protocolSerializationFormatError
}

/**
+8 −49
Original line number Diff line number Diff line
@@ -5,10 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
@@ -20,7 +16,6 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
@@ -124,13 +119,7 @@ class ServerHttpBoundProtocolGenerator(
) : ServerProtocolGenerator(
        protocol,
        ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations),
    ) {
    // Define suffixes for operation input / output / error wrappers
    companion object {
        const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper"
        const val OPERATION_OUTPUT_WRAPPER_SUFFIX = "OperationOutputWrapper"
    }
}
    )

class ServerHttpBoundProtocolPayloadGenerator(
    codegenContext: CodegenContext,
@@ -697,8 +686,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        inputShape: StructureShape,
        bindings: List<HttpBindingDescriptor>,
    ) {
        val httpBindingGenerator =
            ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
        val structuredDataParser = protocol.structuredDataParser()
        Attribute.AllowUnusedMut.render(this)
        rust(
@@ -740,7 +727,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        for (binding in bindings) {
            val member = binding.member
            val parsedValue =
                serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
                serverRenderBindingParser(binding, operationShape, httpBindingGenerator(operationShape), structuredDataParser)
            val valueToSet =
                if (symbolProvider.toSymbol(binding.member).isOptional()) {
                    "Some(value)"
@@ -801,13 +788,8 @@ class ServerHttpBoundProtocolTraitImplGenerator(
                val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
                    rust("#T($body)", structuredDataParser.payloadParser(binding.member))
                }
                val errorSymbol = getDeserializePayloadErrorSymbol(binding)
                val deserializer =
                    httpBindingGenerator.generateDeserializePayloadFn(
                        binding,
                        errorSymbol,
                        structuredHandler = structureShapeHandler,
                    )
                    httpBindingGenerator.generateDeserializePayloadFn(binding, structuredHandler = structureShapeHandler)
                return writable {
                    if (binding.member.isStreaming(model)) {
                        rustTemplate(
@@ -1196,9 +1178,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        binding: HttpBindingDescriptor,
        operationShape: OperationShape,
    ) {
        val httpBindingGenerator =
            ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
        val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
        val deserializer = httpBindingGenerator(operationShape).generateDeserializeHeaderFn(binding)
        writer.rustTemplate(
            """
            #{deserializer}(&headers)?
@@ -1215,8 +1195,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
    ) {
        check(binding.location == HttpLocation.PREFIX_HEADERS)

        val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
        val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding)
        val deserializer = httpBindingGenerator(operationShape).generateDeserializePrefixHeadersFn(binding)
        writer.rustTemplate(
            """
            #{deserializer}(&headers)?
@@ -1300,33 +1279,13 @@ class ServerHttpBoundProtocolTraitImplGenerator(
        }
    }

    /**
     * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the
     * shape targeted by the `httpPayload` trait.
     */
    private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol {
        check(binding.location == HttpLocation.PAYLOAD)

        if (model.expectShape(binding.member.target) is StringShape) {
            return protocol.requestRejection(runtimeConfig).toSymbol()
        }
        return when (codegenContext.protocol) {
            RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
                RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol()
            }
            RestXmlTrait.ID -> {
                RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol()
            }
            else -> {
                TODO("Protocol ${codegenContext.protocol} not supported yet")
            }
        }
    }

    private fun streamingBodyTraitBounds(operationShape: OperationShape) =
        if (operationShape.inputShape(model).hasStreamingMember(model)) {
            "\n B: Into<#{SmithyTypes}::byte_stream::ByteStream>,"
        } else {
            ""
        }

    private fun httpBindingGenerator(operationShape: OperationShape) =
        ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
}