Unverified Commit acc56f83 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Always set `Content-Type` header in RestJson1 responses (#1131)

The protocol specifies it should be set by default to `application/json`
in HTTP responses, even when there are no output shape members, since in
that case an empty JSON payload `{}` is serialized. There are server
protocol tests asserting this behavior.

To implement this behavior, the `RestJsonHttpBindingResolver` class has
been introduced, which mostly delegates to the
`HttpTraitHttpBindingResolver`, but tweaks the response `Content-Type`
header.

Note this behavior is different than the client's, which should not
serialize a JSON payload if there are no input shape members to be
serialized to the body, and as such does not need to set the
`Content-Type` header.

This commit also deletes the `ServerRestJson` and `ServerRestXml`
classes, since they were exact copies of their client counterparts. The
latter was hosted in `ServerRustXml.kt`, which has been renamed to
`ServerRestXml.kt`.
parent 0228d5ed
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -506,7 +506,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
            FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),

            FailingTest(RestJson, "RestJsonEmptyInputAndEmptyOutput", Action.Response),
            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
@@ -527,10 +526,8 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithMediaTypeWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
            FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
            FailingTest(RestJson, "StringPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
+1 −1
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ private class ServerHttpProtocolImplGenerator(
    private val symbolProvider = codegenContext.symbolProvider
    private val model = codegenContext.model
    private val runtimeConfig = codegenContext.runtimeConfig
    val httpBindingResolver = protocol.httpBindingResolver
    private val httpBindingResolver = protocol.httpBindingResolver
    private val operationDeserModule = RustModule.private("operation_deser")
    private val operationSerModule = RustModule.private("operation_ser")

+4 −77
Original line number Diff line number Diff line
@@ -6,35 +6,21 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.restJsonFieldName
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.JsonSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.RestJson

/*
 * RestJson1 server-side protocol implementation and details. This factory generates the [ServerHttpProtocolGenerator]
 * RestJson1 server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
 * with RestJson1 specific configurations.
 */
class ServerRestJsonFactory : ProtocolGeneratorFactory<ServerHttpProtocolGenerator> {
    override fun protocol(codegenContext: CodegenContext): Protocol = ServerRestJson(codegenContext)
    override fun protocol(codegenContext: CodegenContext): Protocol = RestJson(codegenContext)

    override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpProtocolGenerator =
        ServerHttpProtocolGenerator(codegenContext, ServerRestJson(codegenContext))
        ServerHttpProtocolGenerator(codegenContext, RestJson(codegenContext))

    override fun transformModel(model: Model): Model = model

@@ -53,62 +39,3 @@ class ServerRestJsonFactory : ProtocolGeneratorFactory<ServerHttpProtocolGenerat
        )
    }
}

/*
 * RestJson1 implementation. This class is fed to the [ServerHttpProtocolGenerator] to configure things like
 * wire-protocol serializers, content-type, etc.
 */
class ServerRestJson(private val codegenContext: CodegenContext) : Protocol {
    private val runtimeConfig = codegenContext.runtimeConfig
    private val errorScope = arrayOf(
        "Bytes" to RuntimeType.Bytes,
        "Error" to RuntimeType.GenericError(runtimeConfig),
        "HeaderMap" to RuntimeType.http.member("HeaderMap"),
        "JsonError" to CargoDependency.smithyJson(runtimeConfig).asType().member("deserialize::Error"),
        "Response" to RuntimeType.http.member("Response"),
        "JsonErrors" to RuntimeType.jsonErrors(runtimeConfig),
    )
    private val jsonDeserModule = RustModule.private("json_deser")

    override val httpBindingResolver: HttpBindingResolver =
        HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/json"))

    override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

    override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
        return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
    }

    override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
        return JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)
    }

    // NOTE: this method is only needed for the little part of client-codegen we use in tests.
    override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType {
        return RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { writer ->
            writer.rustTemplate(
                """
                pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{JsonError}> {
                    #{JsonErrors}::parse_generic_error(response.body(), response.headers())
                }
                """,
                *errorScope
            )
        }
    }

    // NOTE: this method is only needed for the little part of client-codegen we use in tests.
    override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType {
        return RuntimeType.forInlineFun("parse_event_stream_generic_error", jsonDeserModule) { writer ->
            writer.rustTemplate(
                """
                pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{JsonError}> {
                    // Note: HeaderMap::new() doesn't allocate
                    #{json_errors}::parse_generic_error(payload, &#{HeaderMap}::new())
                }
                """,
                *errorScope
            )
        }
    }
}
+42 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml

/*
 * RestXml server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator]
 * with RestXml specific configurations.
 */
class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { RestXml(it) }) :
    ProtocolGeneratorFactory<ServerHttpProtocolGenerator> {
    override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext)

    override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpProtocolGenerator =
        ServerHttpProtocolGenerator(codegenContext, RestXml(codegenContext))

    override fun transformModel(model: Model): Model = model

    override fun support(): ProtocolSupport {
        return ProtocolSupport(
            /* Client support */
            requestSerialization = false,
            requestBodySerialization = false,
            responseDeserialization = false,
            errorDeserialization = false,
            /* Server support */
            requestDeserialization = true,
            requestBodyDeserialization = true,
            responseSerialization = true,
            errorSerialization = true
        )
    }
}
+0 −106
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustModule
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpTraitHttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolContentTypes
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.RestXmlParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.XmlBindingTraitSerializerGenerator
import software.amazon.smithy.rust.codegen.util.expectTrait

class ServerRestXmlFactory(private val generator: (CodegenContext) -> Protocol = { ServerRestXml(it) }) :
    ProtocolGeneratorFactory<ServerHttpProtocolGenerator> {
    override fun protocol(codegenContext: CodegenContext): Protocol = generator(codegenContext)

    override fun buildProtocolGenerator(codegenContext: CodegenContext): ServerHttpProtocolGenerator =
        ServerHttpProtocolGenerator(codegenContext, ServerRestXml(codegenContext))

    override fun transformModel(model: Model): Model = model

    override fun support(): ProtocolSupport {
        return ProtocolSupport(
            /* Client support */
            requestSerialization = false,
            requestBodySerialization = false,
            responseDeserialization = false,
            errorDeserialization = false,
            /* Server support */
            requestDeserialization = true,
            requestBodyDeserialization = true,
            responseSerialization = true,
            errorSerialization = true
        )
    }
}

open class ServerRestXml(private val codegenContext: CodegenContext) : Protocol {
    private val restXml = codegenContext.serviceShape.expectTrait<RestXmlTrait>()
    private val runtimeConfig = codegenContext.runtimeConfig
    private val errorScope = arrayOf(
        "Bytes" to RuntimeType.Bytes,
        "Error" to RuntimeType.GenericError(runtimeConfig),
        "HeaderMap" to RuntimeType.http.member("HeaderMap"),
        "Response" to RuntimeType.http.member("Response"),
        "XmlError" to CargoDependency.smithyXml(runtimeConfig).asType().member("decode::XmlError")
    )
    private val xmlDeserModule = RustModule.private("xml_deser")

    protected val restXmlErrors: RuntimeType = when (restXml.isNoErrorWrapping) {
        true -> RuntimeType.unwrappedXmlErrors(runtimeConfig)
        false -> RuntimeType.wrappedXmlErrors(runtimeConfig)
    }

    override val httpBindingResolver: HttpBindingResolver =
        HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes.consistent("application/xml"))

    override val defaultTimestampFormat: TimestampFormatTrait.Format =
        TimestampFormatTrait.Format.DATE_TIME

    override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
        return RestXmlParserGenerator(codegenContext, restXmlErrors)
    }

    override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
        return XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver)
    }

    override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType =
        RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { writer ->
            writer.rustBlockTemplate(
                "pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlError}>",
                *errorScope
            ) {
                rust("#T::parse_generic_error(response.body().as_ref())", restXmlErrors)
            }
        }

    override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType =
        RuntimeType.forInlineFun("parse_event_stream_generic_error", xmlDeserModule) { writer ->
            writer.rustBlockTemplate(
                "pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{XmlError}>",
                *errorScope
            ) {
                rust("#T::parse_generic_error(payload.as_ref())", restXmlErrors)
            }
        }
}
Loading