Unverified Commit 5db1f358 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Set the `X-Amzn-Errortype` header in restJson1 responses (#1218)

It turns out this header is mandatory for error responses in restJson1;
there was a bug in the spec that has been fixed in
https://github.com/awslabs/smithy/pull/1099
parent 3d0b98c1
Loading
Loading
Loading
Loading
+2 −26
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@ import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import java.util.logging.Logger
import kotlin.reflect.KFunction1

/**
 * Generate protocol tests for an operation
@@ -1379,27 +1380,6 @@ class ServerProtocolTestGenerator(
                    """.trimMargin()
                ).asObjectNode().get()
            ).build()
        // The following tests assume that errors in responses are identified by an `X-Amzn-Errortype` header with
        // the error shape name.
        // However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations
        // serialize error types using a `__type` field in the body.
        // Our implementation follows this recommendation, so we fix the tests by removing the header and instead expecting
        // the error type to be in the body.
        private fun fixRestJsonEmptyComplexErrorWithNoMessage(testCase: HttpResponseTestCase): HttpResponseTestCase =
            testCase.toBuilder()
                .headers(emptyMap())
                .body("""{"__type":"ComplexError"}""")
                .build()
        private fun fixRestJsonInvalidGreetingError(testCase: HttpResponseTestCase): HttpResponseTestCase =
            testCase.toBuilder()
                .headers(emptyMap())
                .body("""{"Message":"Hi","__type":"InvalidGreeting"}""")
                .build()
        private fun fixRestJsonComplexErrorWithNoMessage(testCase: HttpResponseTestCase): HttpResponseTestCase =
            testCase.toBuilder()
                .headers(emptyMap())
                .body("""{"Nested":{"Fooooo":"bar"},"TopLevel":"Top level","__type":"ComplexError"}""")
                .build()

        // These are tests whose definitions in the `awslabs/smithy` repository are wrong.
        // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
@@ -1413,10 +1393,6 @@ class ServerProtocolTestGenerator(
            Pair(RestJson, "RestJsonQueryStringEscaping") to ::fixRestJsonQueryStringEscaping,
        )

        private val BrokenResponseTests = mapOf(
            Pair(RestJson, "RestJsonEmptyComplexErrorWithNoMessage") to ::fixRestJsonEmptyComplexErrorWithNoMessage,
            Pair(RestJson, "RestJsonInvalidGreetingError") to ::fixRestJsonInvalidGreetingError,
            Pair(RestJson, "RestJsonComplexErrorWithNoMessage") to ::fixRestJsonComplexErrorWithNoMessage,
        )
        private val BrokenResponseTests: Map<Pair<String, String>, KFunction1<HttpResponseTestCase, HttpResponseTestCase>> = mapOf()
    }
}
+26 −8
Original line number Diff line number Diff line
@@ -495,26 +495,29 @@ private class ServerHttpProtocolImplGenerator(
     * It will generate response headers for the operation's output shape, unless [errorShape] is non-null, in which
     * case it will generate response headers for the given error shape.
     *
     * It serializes as HTTP headers shape members that are bound with `httpHeader` and `httpPrefixHeader`.
     *
     * The `Content-Type` header is also set according to the protocol and the contents of the shape to be serialized.
     * It sets three groups of headers in order. Headers from one group take precedence over headers in a later group.
     *     1. Headers bound by the `httpHeader` and `httpPrefixHeader` traits.
     *     2. The protocol-specific `Content-Type` header for the operation.
     *     3. Additional protocol-specific headers for errors, if [errorShape] is non-null.
     */
    private fun RustWriter.serverRenderResponseHeaders(operationShape: OperationShape, errorShape: StructureShape? = null) {
        val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape)
        val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape)
        if (addHeadersFn != null) {
            // notice that we need to borrow the output only for output shapes but not for error shapes
            val outputOwnedOrBorrow = if (errorShape == null) "&output" else "output"
            // Notice that we need to borrow the output only for output shapes but not for error shapes.
            val outputOwnedOrBorrowed = if (errorShape == null) "&output" else "output"
            rust(
                """
                builder = #{T}($outputOwnedOrBorrow, builder)?;
                builder = #{T}($outputOwnedOrBorrowed, builder)?;
                """.trimIndent(),
                addHeadersFn
            )
        }

        // set the content type header *after* the response bindings headers have been set
        // to allow operations that bind a member to content-type to take precedence
        // Set the `Content-Type` header *after* the response bindings headers have been set,
        // to allow operations that bind a member to `Content-Type` (which we set earlier) to take precedence (this is
        // because we always use `set_response_header_if_absent`, so the _first_ header value we set for a given
        // header name is the one that takes precedence).
        val contentType = httpBindingResolver.responseContentType(operationShape)
        if (contentType != null) {
            rustTemplate(
@@ -528,6 +531,21 @@ private class ServerHttpProtocolImplGenerator(
                *codegenScope
            )
        }

        if (errorShape != null) {
            for ((headerName, headerValue) in protocol.additionalErrorResponseHeaders(errorShape)) {
                rustTemplate(
                    """
                    builder = #{header_util}::set_response_header_if_absent(
                        builder,
                        http::header::HeaderName::from_static("$headerName"),
                        "$headerValue"
                    );
                    """,
                    *codegenScope
                )
            }
        }
    }

    private fun serverRenderResponseCodeBinding(
+7 −7
Original line number Diff line number Diff line
@@ -410,14 +410,14 @@ class HttpBindingGenerator(
        shape: Shape,
        httpMessageType: HttpMessageType = HttpMessageType.REQUEST
    ): RuntimeType? {
        val headerBindings = when (httpMessageType) {
            HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.HEADER)
            HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.HEADER)
        }
        val prefixHeaderBinding = when (httpMessageType) {
            HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS)
            HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS)
        }.getOrNull(0) // Only a single structure member can be bound to `httpPrefixHeaders`.
        val (headerBindings, prefixHeaderBinding) = when (httpMessageType) {
            // Only a single structure member can be bound by `httpPrefixHeaders`, hence the `getOrNull(0)`.
            HttpMessageType.REQUEST -> index.getRequestBindings(shape, HttpLocation.HEADER) to
                index.getRequestBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0)
            HttpMessageType.RESPONSE -> index.getResponseBindings(shape, HttpLocation.HEADER) to
                index.getResponseBindings(shape, HttpLocation.PREFIX_HEADERS).getOrNull(0)
        }

        if (headerBindings.isEmpty() && prefixHeaderBinding == null) {
            return null
        }
+8 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
@@ -42,6 +43,13 @@ interface Protocol {
    /** Returns additional HTTP headers that should be included in HTTP requests for the given operation for this protocol. */
    fun additionalRequestHeaders(operationShape: OperationShape): List<Pair<String, String>> = emptyList()

    /**
     * Returns additional HTTP headers that should be included in HTTP responses for the given error shape.
     * These MUST all be lowercase, or the application will panic, as per
     * https://docs.rs/http/latest/http/header/struct.HeaderName.html#method.from_static
     */
    fun additionalErrorResponseHeaders(errorShape: StructureShape): List<Pair<String, String>> = emptyList()

    /** Returns a deserialization code generator for this protocol */
    fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator

+9 −1
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.JsonNameTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
@@ -60,7 +61,7 @@ class RestJsonHttpBindingResolver(
) : HttpTraitHttpBindingResolver(model, contentTypes) {
    /**
     * In the RestJson1 protocol, HTTP responses have a default `Content-Type: application/json` header if it is not
     * overridden by a specific mechanism (e.g. an output shape member is targeted with `httpPayload` or `mediaType` traits.
     * overridden by a specific mechanism e.g. an output shape member is targeted with `httpPayload` or `mediaType` traits.
     */
    override fun responseContentType(operationShape: OperationShape): String =
        super.responseContentType(operationShape) ?: "application/json"
@@ -83,6 +84,13 @@ class RestJson(private val codegenContext: CodegenContext) : Protocol {

    override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

    /**
     * RestJson1 implementations can denote errors in responses in several ways.
     * New server-side protocol implementations MUST use a header field named `X-Amzn-Errortype`.
     */
    override fun additionalErrorResponseHeaders(errorShape: StructureShape): List<Pair<String, String>> =
        listOf("x-amzn-errortype" to errorShape.id.name)

    override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
        JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName)

Loading