Unverified Commit f76bc159 authored by Guy Margalit's avatar Guy Margalit Committed by GitHub
Browse files

Server streaming body (#1023)



Add support for server blob streaming requests and responses

Data is streamed over the HTTP body.

Signed-off-by: default avatarGuy Margalit <guymguym@gmail.com>
Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent 907c0f3f
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape

/**
@@ -39,6 +41,7 @@ class ServerOperationHandlerGenerator(
        "PinProjectLite" to ServerCargoDependency.PinProjectLite.asType(),
        "Tower" to ServerCargoDependency.Tower.asType(),
        "FuturesUtil" to ServerCargoDependency.FuturesUtil.asType(),
        "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "Phantom" to ServerRuntimeType.Phantom,
@@ -132,13 +135,18 @@ class ServerOperationHandlerGenerator(
        } else {
            symbolProvider.toSymbol(operation.outputShape(model)).fullName
        }
        val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) {
            "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
        } else {
            ""
        }
        return """
            $inputFn
            Fut: std::future::Future<Output = $outputType> + Send,
            B: $serverCrate::HttpBody + Send + 'static,
            B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds
            B::Data: Send,
            B::Error: Into<$serverCrate::BoxError>,
            $serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::HttpBody>::Error>
        """
        """.trimIndent()
    }
}
+94 −72
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@@ -326,15 +327,37 @@ class ServerProtocolTestGenerator(
            """
            use #{AxumCore}::extract::FromRequest;
            let mut http_request = #{AxumCore}::extract::RequestParts::new(http_request);
            let input_wrapper = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request");
            let input = input_wrapper.0;
            let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0;
            """,
            *codegenScope,
        )
        if (operationShape.outputShape(model).hasStreamingMember(model)) {
            rustWriter.rust("""todo!("streaming types aren't supported yet");""")

        if (inputShape.hasStreamingMember(model)) {
            // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
            // and handle the equality assertion separately.
            for (member in inputShape.members()) {
                val memberName = codegenContext.symbolProvider.toMemberName(member)
                if (member.isStreaming(codegenContext.model)) {
                    rustWriter.rustTemplate(
                        """
                        #{AssertEq}(
                            parsed.$memberName.collect().await.unwrap().into_bytes(),
                            expected.$memberName.collect().await.unwrap().into_bytes()
                        );
                        """,
                        *codegenScope
                    )
                } else {
                    rustWriter.rustTemplate(
                        """
                        #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
                        """,
                        *codegenScope
                    )
                }
            }
        } else {
            rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope)
            rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
        }
    }

@@ -511,16 +534,7 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
            FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),

@@ -592,7 +606,9 @@ class ServerProtocolTestGenerator(
            ).build()
        private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
            testCase.toBuilder().params(
                 Node.parse("""{
                Node.parse(
                    """
                    {
                        "queryString": "Hello there",
                        "queryStringList": ["a", "b", "c"],
                        "queryStringSet": ["a", "b", "c"],
@@ -631,16 +647,22 @@ class ServerProtocolTestGenerator(
                            "Enum": ["Foo"],
                            "EnumList": ["Foo", "Baz", "Bar"]
                        }
                }""".trimMargin()).asObjectNode().get()
                    }
                    """.trimMargin()
                ).asObjectNode().get()
            ).build()
        private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase =
            testCase.toBuilder().params(
                Node.parse("""{
                Node.parse(
                    """
                    {
                        "queryString": "%:/?#[]@!${'$'}&'()*+,;=😹",
                        "queryParamsMapOfStringList": {
                            "String": ["%:/?#[]@!${'$'}&'()*+,;=😹"]
                        }
               }""".trimMargin()).asObjectNode().get()
                    }
                    """.trimMargin()
                ).asObjectNode().get()
            ).build()
        // This test assumes that errors in responses are identified by an `X-Amzn-Errortype` header with the error shape name.
        // However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations
+87 −72
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
import software.amazon.smithy.model.shapes.CollectionShape
@@ -55,6 +54,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
@@ -119,6 +119,7 @@ private class ServerHttpProtocolImplGenerator(
        "PercentEncoding" to CargoDependency.PercentEncoding.asType(),
        "Regex" to CargoDependency.Regex.asType(),
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
        "SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "http" to RuntimeType.http
@@ -132,13 +133,12 @@ private class ServerHttpProtocolImplGenerator(
    }

    /*
     * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request
     * and response bodies, that is, models without streaming traits
     * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html).
     * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or
     * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
     * Generation of `FromRequest` and `IntoResponse`.
     * For non-streaming request bodies, that is, models without streaming traits
     * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html)
     * we require the HTTP body to be fully read in memory before parsing or deserialization.
     * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
     * an HTTP response to `Bytes`.
     * TODO Add support for streaming.
     * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server.
     */
    private fun RustWriter.renderTraits(
@@ -147,38 +147,24 @@ private class ServerHttpProtocolImplGenerator(
        operationShape: OperationShape
    ) {
        val operationName = symbolProvider.toSymbol(operationShape).name
        // Implement Axum `FromRequest` trait for input types.
        val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

        val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) {
            // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait.
            // It will first offer the streaming input to the parser and potentially read the body into memory
            // if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
            """
            async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                todo!("Streaming support for input shapes is not yet supported in `smithy-rs`")
            }
            """.trimIndent()
        } else {
            """
            async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                Ok($inputName(#{parse_request}(req).await?))
            }
            """.trimIndent()
        }
        // Implement Axum `FromRequest` trait for input types.
        rustTemplate(
            """
            pub struct $inputName(pub #{I});
            ##[#{AsyncTrait}::async_trait]
            impl<B> #{AxumCore}::extract::FromRequest<B> for $inputName
            where
                B: #{SmithyHttpServer}::HttpBody + Send,
                B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
                B::Data: Send,
                B::Error: Into<#{SmithyHttpServer}::BoxError>,
                #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
            {
                type Rejection = #{SmithyRejection};
                $fromRequest
                async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
                    Ok($inputName(#{parse_request}(req).await?))
                }
            }
            """.trimIndent(),
            *codegenScope,
@@ -187,21 +173,19 @@ private class ServerHttpProtocolImplGenerator(
        )

        // Implement Axum `IntoResponse` for output types.

        val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
        val errorSymbol = operationShape.errorSymbol(symbolProvider)

        val httpExtensions = setHttpExtensions(operationShape)
        // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait.
        // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler.
        val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")"

        if (operationShape.errors.isNotEmpty()) {
            val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
            // The output of fallible operations is a `Result` which we convert into an
            // isomorphic `enum` type we control that can in turn be converted into a response.
            val intoResponseImpl =
                """
                let mut response = match self {
                    Self::Output(o) => {
                        match #{serialize_response}(&o) {
                        match #{serialize_response}(o) {
                            Ok(response) => response,
                            Err(e) => {
                                e.into_response()
@@ -223,9 +207,7 @@ private class ServerHttpProtocolImplGenerator(
                $httpExtensions
                response
                """.trimIndent()
            }
            // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
            // that can in turn be converted into a response.

            rustTemplate(
                """
                pub enum $outputName {
@@ -246,27 +228,25 @@ private class ServerHttpProtocolImplGenerator(
                "serialize_error" to serverSerializeError(operationShape)
            )
        } else {
            val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) {
                intoResponseStreaming
            } else {
            // The output of non-fallible operations is a model type which we convert into
            // a "wrapper" unit `struct` type we control that can in turn be converted into a response.
            val intoResponseImpl =
                """
                let mut response = match #{serialize_response}(&self.0) {
                let mut response = match #{serialize_response}(self.0) {
                    Ok(response) => response,
                    Err(e) => e.into_response()
                };
                $httpExtensions
                response
                """.trimIndent()
            }
            // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
            // we control that can in turn be converted into a response.

            rustTemplate(
                """
                pub struct $outputName(pub #{O});
                ##[#{AsyncTrait}::async_trait]
                impl #{AxumCore}::response::IntoResponse for $outputName {
                    fn into_response(self) -> #{AxumCore}::response::Response {
                        $handleSerializeOutput
                        $intoResponseImpl
                    }
                }
                """.trimIndent(),
@@ -335,6 +315,7 @@ private class ServerHttpProtocolImplGenerator(
        val inputSymbol = symbolProvider.toSymbol(inputShape)
        val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT)
        val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else ""

        return RuntimeType.forInlineFun(fnName, operationDeserModule) {
            Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
            it.rustBlockTemplate(
@@ -346,11 +327,11 @@ private class ServerHttpProtocolImplGenerator(
                    #{SmithyRejection}
                >
                where
                    B: #{SmithyHttpServer}::HttpBody + Send,
                    B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
                    B::Data: Send,
                    B::Error: Into<#{SmithyHttpServer}::BoxError>,
                    #{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
                """,
                """.trimIndent(),
                *codegenScope,
                "I" to inputSymbol,
            ) {
@@ -371,8 +352,12 @@ private class ServerHttpProtocolImplGenerator(
        val outputSymbol = symbolProvider.toSymbol(outputShape)
        return RuntimeType.forInlineFun(fnName, operationSerModule) {
            Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)

            // Note we only need to take ownership of the output in the case that it contains streaming members.
            // However we currently always take ownership here, but worth noting in case in the future we want
            // to generate different signatures for streaming vs non-streaming for some reason.
            it.rustBlockTemplate(
                "pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
                "pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
                *codegenScope,
                "O" to outputSymbol,
            ) {
@@ -459,13 +444,6 @@ private class ServerHttpProtocolImplGenerator(
        operationShape: OperationShape,
        bindings: List<HttpBindingDescriptor>,
    ) {
        val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
        structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
            rust(
                "let payload = #T(output)?;",
                serializer
            )
        } ?: rust("""let payload = "";""")
        // avoid non-usage warnings for response
        Attribute.AllowUnusedMut.render(this)
        rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope)
@@ -477,6 +455,24 @@ private class ServerHttpProtocolImplGenerator(
                serializedValue(this)
            }
        }
        val streamingMember = operationShape.outputShape(model).findStreamingMember(model)
        if (streamingMember != null) {
            val memberName = symbolProvider.toMemberName(streamingMember)
            rustTemplate(
                """
                let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName);
                """,
                *codegenScope,
            )
        } else {
            val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
            structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
                rust(
                    "let payload = #T(&output)?;",
                    serializer
                )
            } ?: rust("""let payload = "";""")
        }
        rustTemplate(
            """
            builder.body(#{SmithyHttpServer}::body::to_boxed(payload))?
@@ -512,9 +508,11 @@ private class ServerHttpProtocolImplGenerator(
        val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape)
        val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape)
        if (addHeadersFn != null) {
            // notice that we need to borrow the output only for output shapes but not for error shapes
            val outputOwnedOrBorrow = if (errorShape == null) "&output" else "output"
            rust(
                """
                builder = #{T}(output, builder)?;
                builder = #{T}($outputOwnedOrBorrow, builder)?;
                """.trimIndent(),
                addHeadersFn
            )
@@ -528,12 +526,11 @@ private class ServerHttpProtocolImplGenerator(
        val operationName = symbolProvider.toSymbol(operationShape).name
        val member = binding.member
        return when (binding.location) {
            HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.DOCUMENT -> {
                // All of these are handled separately.
                null
            }
            HttpLocation.HEADER,
            HttpLocation.PREFIX_HEADERS,
            HttpLocation.DOCUMENT,
            HttpLocation.PAYLOAD -> {
                logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings")
                // All of these are handled separately.
                null
            }
            HttpLocation.RESPONSE_CODE -> writable {
@@ -608,6 +605,19 @@ private class ServerHttpProtocolImplGenerator(
        return when (binding.location) {
            HttpLocation.HEADER -> writable { serverRenderHeaderParser(this, binding, operationShape) }
            HttpLocation.PAYLOAD -> {
                return if (binding.member.isStreaming(model)) {
                    writable {
                        rustTemplate(
                            """
                            {
                                let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?;
                                Some(body.into())
                            }
                            """.trimIndent(),
                            *codegenScope
                        )
                    }
                } else {
                    val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
                        rust("#T($body)", structuredDataParser.payloadParser(binding.member))
                    }
@@ -617,9 +627,6 @@ private class ServerHttpProtocolImplGenerator(
                        errorSymbol,
                        structuredHandler = structureShapeHandler
                    )
                return if (binding.member.isStreaming(model)) {
                    writable { rust("""todo!("streaming request bodies");""") }
                } else {
                    writable {
                        rustTemplate(
                            """
@@ -1047,4 +1054,12 @@ private class ServerHttpProtocolImplGenerator(
            }
        }
    }

    private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String {
        if (operationShape.inputShape(model).hasStreamingMember(model)) {
            return "\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
        } else {
            return ""
        }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ bytes = "1.1"
futures-util = { version = "0.3", default-features = false }
http = "0.2"
http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] }
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
mime = "0.3"
nom = "7"
pin-project-lite = "0.2"
+6 −0
Original line number Diff line number Diff line
@@ -178,6 +178,12 @@ impl From<aws_smithy_types::date_time::DateTimeParseError> for SmithyRejection {
    }
}

impl From<aws_smithy_types::date_time::DateTimeFormatError> for SmithyRejection {
    fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self {
        SmithyRejection::Serialize(Serialize::from_err(err))
    }
}

impl From<aws_smithy_types::primitive::PrimitiveParseError> for SmithyRejection {
    fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self {
        SmithyRejection::Deserialize(Deserialize::from_err(err))
Loading