Unverified Commit 171d76fc authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Add support for `httpPayload` in server responses (#1158)

Note that with this change, payload serializers now return `Vec<u8>` or
`ByteStream` instead of always `SdkBody`, since the server does not use
`SdkBody`. The caller then needs to convert the byte slab into the
appropriate HTTP body type. To this end, `HttpBoundProtocolBodyGenerator`
has been extracted into its own file and renamed to
`HttpBoundProtocolPayloadGenerator`.

This commit also makes the body types exports from
`aws-smithy-http-server` `#[doc(hidden)]`, since they should only be
used by generated code, and exports all of them from `body.rs` instead
of `lib.rs`.

Comments have been added where appropriate to document which serializers
are called by the client / server and how their use differs.
parent 548f6ed4
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.client.FluentClient
import software.amazon.smithy.rust.codegen.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.util.cloneOperation
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -164,7 +164,7 @@ class AwsInputPresignedMethod(
            MakeOperationGenerator(
                codegenContext,
                protocol,
                HttpBoundProtocolBodyGenerator(codegenContext, protocol),
                HttpBoundProtocolPayloadGenerator(codegenContext, protocol),
                // Prefixed with underscore to avoid colliding with modeled functions
                functionName = makeOperationFn,
                public = false,
+5 −5
Original line number Diff line number Diff line
@@ -73,7 +73,7 @@ class ServerOperationHandlerGenerator(
                let error = aws_smithy_http_server::ExtensionRejection::new(r.to_string());
                let mut response = r.into_response();
                response.extensions_mut().insert(error);
                return response.map($serverCrate::boxed);
                return response.map($serverCrate::body::boxed);
                }
            """.trimIndent()
            writer.rustBlockTemplate(
@@ -99,7 +99,7 @@ class ServerOperationHandlerGenerator(
                rustTemplate(
                    """
                    type Sealed = #{ServerOperationHandler}::sealed::Hidden;
                    async fn call(self, req: #{http}::Request<B>) -> #{http}::Response<#{SmithyHttpServer}::BoxBody> {
                    async fn call(self, req: #{http}::Request<B>) -> #{http}::Response<#{SmithyHttpServer}::body::BoxBody> {
                        let mut req = #{AxumCore}::extract::RequestParts::new(req);
                        use #{AxumCore}::extract::FromRequest;
                        use #{AxumCore}::response::IntoResponse;
@@ -109,7 +109,7 @@ class ServerOperationHandlerGenerator(
                        };
                        $callImpl
                        let output_wrapper: $outputWrapperName = output_inner.into();
                        output_wrapper.into_response().map(#{SmithyHttpServer}::boxed)
                        output_wrapper.into_response().map(#{SmithyHttpServer}::body::boxed)
                    }
                    """,
                    *codegenScope
@@ -143,10 +143,10 @@ class ServerOperationHandlerGenerator(
        return """
            $inputFn
            Fut: std::future::Future<Output = $outputType> + Send,
            B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds
            B: $serverCrate::body::HttpBody + Send + 'static, $streamingBodyTraitBounds
            B::Data: Send,
            B::Error: Into<$serverCrate::BoxError>,
            $serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::HttpBody>::Error>
            $serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::body::HttpBody>::Error>
        """.trimIndent()
    }
}
+2 −9
Original line number Diff line number Diff line
@@ -223,7 +223,7 @@ class ServerProtocolTestGenerator(
        }
        rustTemplate(
            """
            .body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()})))
            .body(#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()})))
            .unwrap();
            """,
            *codegenScope
@@ -508,8 +508,6 @@ class ServerProtocolTestGenerator(

            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutput", Action.Response),
            FailingTest(RestJson, "DocumentTypeAsPayloadOutputString", Action.Response),
            FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestJson, "RestJsonInvalidGreetingError", Action.Response),
@@ -520,13 +518,8 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonFooErrorWithDunderType", Action.Response),
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeAndNamespace", Action.Response),
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeUriAndNamespace", Action.Response),
            FailingTest(RestJson, "EnumPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithNoBlobBody", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadTraitsWithMediaTypeWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpPayloadWithStructure", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
            FailingTest(RestJson, "StringPayloadResponse", Action.Response),
            FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
            FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
+42 −59
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
@@ -41,18 +42,18 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolTraitImplGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolBodyGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
@@ -75,7 +76,7 @@ class ServerHttpProtocolGenerator(
) : ProtocolGenerator(
    codegenContext,
    protocol,
    MakeOperationGenerator(codegenContext, protocol, HttpBoundProtocolBodyGenerator(codegenContext, protocol)),
    MakeOperationGenerator(codegenContext, protocol, HttpBoundProtocolPayloadGenerator(codegenContext, protocol)),
    ServerHttpProtocolImplGenerator(codegenContext, protocol),
) {
    // Define suffixes for operation input / output / error wrappers
@@ -157,10 +158,10 @@ private class ServerHttpProtocolImplGenerator(
            ##[#{AsyncTrait}::async_trait]
            impl<B> #{AxumCore}::extract::FromRequest<B> for $inputName
            where
                B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
                B: #{SmithyHttpServer}::body::HttpBody + Send, ${streamingBodyTraitBounds(operationShape)}
                B::Data: Send,
                B::Error: Into<#{SmithyHttpServer}::BoxError>,
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
            {
                type Rejection = #{SmithyRejection};
                async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
@@ -328,10 +329,10 @@ private class ServerHttpProtocolImplGenerator(
                    #{SmithyRejection}
                >
                where
                    B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
                    B: #{SmithyHttpServer}::body::HttpBody + Send, ${streamingBodyTraitBounds(operationShape)}
                    B::Data: Send,
                    B::Error: Into<#{SmithyHttpServer}::BoxError>,
                    #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
                    #{SmithyRejection}: From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
                """.trimIndent(),
                *codegenScope,
                "I" to inputSymbol,
@@ -441,42 +442,36 @@ private class ServerHttpProtocolImplGenerator(
        }
    }

    /**
     * Render an HTTP response (headers, response code, body) for an operation's output and the given [bindings].
     */
    private fun RustWriter.serverRenderOutputShapeResponseSerializer(
        operationShape: OperationShape,
        bindings: List<HttpBindingDescriptor>,
    ) {
        // avoid non-usage warnings for response
        Attribute.AllowUnusedMut.render(this)
        rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope)
        serverRenderResponseHeaders(operationShape)
        bindings.find { it.location == HttpLocation.RESPONSE_CODE }?.let { serverRenderResponseCodeBinding(it) }

        for (binding in bindings) {
            val serializedValue = serverRenderBindingSerializer(binding)
            if (serializedValue != null) {
                serializedValue(this)
            }
        }
        val streamingMember = operationShape.outputShape(model).findStreamingMember(model)
        if (streamingMember != null) {
            val memberName = symbolProvider.toMemberName(streamingMember)
        operationShape.outputShape(model).findStreamingMember(model)?.let {
            val memberName = symbolProvider.toMemberName(it)
            rustTemplate(
                """
                let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName);
                let body = #{SmithyHttpServer}::body::to_boxed(#{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName));
                """,
                *codegenScope,
            )
        } else {
            val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
            structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
                rust(
                    "let payload = #T(&output)?;",
                    serializer
                )
            } ?: rust("""let payload = "";""")
        } ?:run {
            val payloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol, httpMessageType = HttpMessageType.RESPONSE)
            withBlockTemplate("let body = #{SmithyHttpServer}::body::to_boxed(", ");", *codegenScope) {
                payloadGenerator.generatePayload(this, "output", operationShape)
            }
        }

        rustTemplate(
            """
            builder.body(#{SmithyHttpServer}::body::to_boxed(payload))?
            builder.body(body)?
            """,
            *codegenScope,
        )
@@ -522,20 +517,12 @@ private class ServerHttpProtocolImplGenerator(
        }
    }

    private fun serverRenderBindingSerializer(
        binding: HttpBindingDescriptor,
    ): Writable? {
        val member = binding.member
        return when (binding.location) {
            HttpLocation.HEADER,
            HttpLocation.PREFIX_HEADERS,
            HttpLocation.DOCUMENT,
            HttpLocation.PAYLOAD -> {
                // All of these are handled separately.
                null
            }
            HttpLocation.RESPONSE_CODE -> writable {
                val memberName = symbolProvider.toMemberName(member)
    private fun serverRenderResponseCodeBinding(
        binding: HttpBindingDescriptor
    ): Writable {
        check(binding.location == HttpLocation.RESPONSE_CODE)
        return writable {
            val memberName = symbolProvider.toMemberName(binding.member)
            rustTemplate(
                """
                let status = output.$memberName
@@ -547,10 +534,6 @@ private class ServerHttpProtocolImplGenerator(
            )
            rust("builder = builder.status(http_status);")
        }
            else -> {
                UNREACHABLE("Unexpected binding location: ${binding.location}")
            }
        }
    }

    private fun RustWriter.serverRenderShapeParser(
@@ -1075,11 +1058,11 @@ private class ServerHttpProtocolImplGenerator(
            }
        }
    }
    private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String {

    private fun streamingBodyTraitBounds(operationShape: OperationShape) =
        if (operationShape.inputShape(model).hasStreamingMember(model)) {
            return "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
            "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
        } else {
            return ""
        }
            ""
        }
}
+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase
 *     - serializing data to an HTTP request (we are a client),
 *     - serializing data to an HTTP response (we are a server),
 */
enum class HttpMessageType {
public enum class HttpMessageType {
    REQUEST, RESPONSE
}

Loading