Unverified Commit 0af394d0 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Add support for `httpPrefixHeaders` in server requests (#1140)

The implemented behavior is to fail to deserialize HTTP requests if
there is more than one value for the same header name. See #1139.
parent f76bc159
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -36,4 +36,8 @@ class ServerRequestBindingGenerator(
        structuredHandler,
        HttpMessageType.REQUEST
    )

    fun generateDeserializePrefixHeadersFn(
        binding: HttpBindingDescriptor,
    ): RuntimeType = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding)
}
+0 −2
Original line number Diff line number Diff line
@@ -503,8 +503,6 @@ class ServerProtocolTestGenerator(
        private val ExpectFail = setOf<FailingTest>(
            // Headers.
            FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request),
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersArePresent", Action.Request),
            FailingTest(RestJson, "RestJsonHttpPrefixHeadersAreNotPresent", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
            FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),

+21 −1
Original line number Diff line number Diff line
@@ -604,6 +604,7 @@ private class ServerHttpProtocolImplGenerator(
        val errorSymbol = getDeserializeErrorSymbol(binding)
        return when (binding.location) {
            HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) }
            HttpLocation.PREFIX_HEADERS -> writable { serverRenderPrefixHeadersParser(this, binding, operationShape) }
            HttpLocation.PAYLOAD -> {
                return if (binding.member.isStreaming(model)) {
                    writable {
@@ -918,7 +919,7 @@ private class ServerHttpProtocolImplGenerator(
    private fun serverRenderHeaderParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) {
        val httpBindingGenerator =
            ServerRequestBindingGenerator(
                ServerRestJson(codegenContext),
                protocol,
                codegenContext,
                operationShape,
            )
@@ -932,6 +933,25 @@ private class ServerHttpProtocolImplGenerator(
        )
    }

    private fun serverRenderPrefixHeadersParser(writer: RustWriter, binding: HttpBindingDescriptor, operationShape: OperationShape) {
        check(binding.location == HttpLocation.PREFIX_HEADERS)

        val httpBindingGenerator =
            ServerRequestBindingGenerator(
                protocol,
                codegenContext,
                operationShape,
            )
        val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding)
        writer.rustTemplate(
            """
            #{deserializer}(request.headers().ok_or(#{SmithyHttpServer}::rejection::HeadersAlreadyExtracted)?)?
            """.trimIndent(),
            "deserializer" to deserializer,
            *codegenScope
        )
    }

    private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType {
        // HTTP bindings we support that contain percent-encoded data.
        check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY)
+6 −4
Original line number Diff line number Diff line
@@ -157,7 +157,9 @@ class HttpBindingGenerator(
                    let headers = #T::headers_for_prefix(header_map, ${binding.locationName.dq()});
                    let out: std::result::Result<_, _> = headers.map(|(key, header_name)| {
                        let values = header_map.get_all(header_name);
                        #T(values.iter()).map(|v| (key.to_string(), v.unwrap()))
                        #T(values.iter()).map(|v| (key.to_string(), v.expect(
                            "we have checked there is at least one value for this header name; please file a bug report under https://github.com/awslabs/smithy-rs/issues
                        ")))
                    }).collect();
                    out.map(Some)
                    """,
@@ -295,7 +297,7 @@ class HttpBindingGenerator(

    /**
     * Parse a value from a header.
     * This function produces an expression which produces the precise output type required by the output shape.
     * This function produces an expression which produces the precise type required by the target shape.
     */
    private fun RustWriter.deserializeFromHeader(targetType: Shape, memberShape: MemberShape) {
        val rustType = symbolProvider.toSymbol(targetType).rustType().stripOuter<RustType.Option>()
@@ -387,9 +389,9 @@ class HttpBindingGenerator(
    }

    /**
     * Generate a unique name for the deserializer function for a given operationShape -> member pair.
     * Generate a unique name for the deserializer function for a given [operationShape] and HTTP binding.
     */
    // rename here technically not required, operations and members cannot be renamed
    // Rename here technically not required, operations and members cannot be renamed.
    private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) =
        "${operationShape.id.getName(service).toSnakeCase()}_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}"