Unverified Commit 85244460 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

`rust-server-codegen`: add support for `@httpPayload` in requests (#1060)

This commit adds support for the `httpPayload` Smithy trait when applied
to operation input members (request deserialization).

The code to deserialize HTTP-bound data from HTTP responses in
`ResponseBindingGenerator.kt` has been moved into a common class,
`HttpBindingGenerator.kt`, since it's useful for both clients and
servers in deserializing data from HTTP requests and responses,
respectively.
parent 391e49ca
Loading
Loading
Loading
Loading
+42 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.generators.http


import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.http.HttpBindingGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol

class ServerRequestBindingGenerator(
    protocol: Protocol,
    codegenContext: CodegenContext,
    operationShape: OperationShape
) {
    private val httpBindingGenerator = HttpBindingGenerator(protocol, codegenContext, operationShape)

    fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType =
        httpBindingGenerator.generateDeserializeHeaderFn(binding)

    fun generateDeserializePayloadFn(
        operationShape: OperationShape,
        binding: HttpBindingDescriptor,
        errorT: RuntimeType,
        structuredHandler: RustWriter.(String) -> Unit,
        docHandler: RustWriter.(String) -> Unit
    ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn(
        operationShape,
        binding,
        errorT,
        structuredHandler,
        docHandler,
        HttpMessageType.REQUEST
    )
}
+0 −4
Original line number Diff line number Diff line
@@ -449,8 +449,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "DocumentOutputNumber", Action.Response),
            FailingTest(RestJson, "DocumentOutputBoolean", Action.Response),
            FailingTest(RestJson, "DocumentOutputArray", Action.Response),
            FailingTest(RestJson, "DocumentTypeAsPayloadInput", Action.Request),
            FailingTest(RestJson, "DocumentTypeAsPayloadInputString", Action.Request),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response),
            FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
@@ -465,7 +463,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeAndNamespace", Action.Response),
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeUriAndNamespace", Action.Response),
            FailingTest(RestJson, "RestJsonHttpChecksumRequired", Action.Request),
            FailingTest(RestJson, "EnumPayloadRequest", Action.Request),
            FailingTest(RestJson, "EnumPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithBlob", Action.Request),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithNoBlobBody", Action.Request),
@@ -483,7 +480,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "HttpPrefixHeadersResponse", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
            FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
            FailingTest(RestJson, "StringPayloadRequest", Action.Request),
            FailingTest(RestJson, "StringPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonIgnoreQueryParamsInResponse", Action.Response),
            FailingTest(RestJson, "RestJsonInputAndOutputWithStringHeaders", Action.Request),
+51 −9
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import software.amazon.smithy.model.node.ExpectationNotMetException
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
@@ -37,7 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTraitImplGenerator
@@ -47,11 +48,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescripto
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger
@@ -508,6 +511,7 @@ private class ServerHttpProtocolImplGenerator(
        inputShape: StructureShape,
        bindings: List<HttpBindingDescriptor>,
    ) {
        val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
        val structuredDataParser = protocol.structuredDataParser(operationShape)
        Attribute.AllowUnusedMut.render(this)
        rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
@@ -529,7 +533,7 @@ private class ServerHttpProtocolImplGenerator(
        }
        for (binding in bindings) {
            val member = binding.member
            val parsedValue = serverRenderBindingParser(binding, operationShape)
            val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
            if (parsedValue != null) {
                withBlock("input = input.${member.setterName()}(", ");") {
                    parsedValue(this)
@@ -548,18 +552,56 @@ private class ServerHttpProtocolImplGenerator(
    private fun serverRenderBindingParser(
        binding: HttpBindingDescriptor,
        operationShape: OperationShape,
        httpBindingGenerator: ServerRequestBindingGenerator,
        structuredDataParser: StructuredDataParserGenerator,
    ): Writable? {
        val operationName = symbolProvider.toSymbol(operationShape).name
        return when (val location = binding.location) {
        val errorSymbol = if (model.expectShape(binding.member.target) is StringShape) {
            CargoDependency.SmithyHttpServer(runtimeConfig).asType().member("rejection").member("SmithyRejection")
        } else {
            CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize").member("Error")
        }
        return when (binding.location) {
            HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) }
            HttpLocation.LABEL -> {
                null
            HttpLocation.PAYLOAD -> {
                val docShapeHandler: RustWriter.(String) -> Unit = { body ->
                    rust(
                        "#T($body)",
                        structuredDataParser.documentParser(operationShape),
                    )
                }
            HttpLocation.DOCUMENT -> {
                val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
                    rust("#T($body)", structuredDataParser.payloadParser(binding.member))
                }
                val deserializer = httpBindingGenerator.generateDeserializePayloadFn(
                    operationShape,
                    binding,
                    errorSymbol,
                    docHandler = docShapeHandler,
                    structuredHandler = structureShapeHandler
                )
                return if (binding.member.isStreaming(model)) {
                    writable { rust("""todo!("streaming request bodies");""") }
                } else {
                    writable {
                        rustTemplate("""
                            {
                                let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?;
                                let bytes = #{Hyper}::body::to_bytes(body).await?;
                                #{Deserializer}(&bytes)?
                            }
                            """,
                            "Deserializer" to deserializer,
                            *codegenScope
                        )
                    }
                }
            }
            HttpLocation.DOCUMENT, HttpLocation.LABEL, HttpLocation.QUERY, HttpLocation.QUERY_PARAMS -> {
                // All of these are handled separately.
                null
            }
            else -> {
                logger.warning("[rust-server-codegen] $operationName: request parsing does not currently support $location bindings")
                logger.warning("[rust-server-codegen] ${operationShape.id}: request parsing does not currently support ${binding.location} bindings")
                null
            }
        }
@@ -829,7 +871,7 @@ private class ServerHttpProtocolImplGenerator(

    private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) {
        val httpBindingGenerator =
            ResponseBindingGenerator(
            ServerRequestBindingGenerator(
                ServerRestJson(codegenContext),
                codegenContext,
                operationShape,
+386 −0

File added.

Preview size limit exceeded, changes collapsed.

+14 −321

File changed.

Preview size limit exceeded, changes collapsed.