Unverified Commit 4ccdadde authored by david-perez's avatar david-perez Committed by GitHub
Browse files

`rust-server-codegen`: fix deserialization of escaped query string data (#1058)

We are currently deserializing the query string into `Vec<(&str,
&str)>`. `serde_urlencoded` panics if the input string slice contains
escaped data, since in that case it needs to allocate a new `String` to
unescape the input string slice's contents.

Instead of deserializing to `Vec<(String, String)>`, we can instead use
`Cow<'a, str>` so that deserialization only allocates when strictly
required.

Reference: https://github.com/serde-rs/serde/issues/1413#issuecomment-494892266
parent 7b061f40
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -20,6 +20,7 @@ object ServerRuntimeType {
        RuntimeType(inlineDependency.name, inlineDependency, namespace = "crate")
        RuntimeType(inlineDependency.name, inlineDependency, namespace = "crate")


    val Phantom = RuntimeType("PhantomData", dependency = null, namespace = "std::marker")
    val Phantom = RuntimeType("PhantomData", dependency = null, namespace = "std::marker")
    val Cow = RuntimeType("Cow", dependency = null, namespace = "std::borrow")


    fun Router(runtimeConfig: RuntimeConfig) =
    fun Router(runtimeConfig: RuntimeConfig) =
        RuntimeType("Router", CargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing")
        RuntimeType("Router", CargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing")
+45 −2
Original line number Original line Diff line number Diff line
@@ -441,7 +441,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
            FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
            FailingTest(RestJson, "RestJsonUnitInputAllowsAccept", Action.Request),
            FailingTest(RestJson, "RestJsonUnitInputAllowsAccept", Action.Request),
            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonAllQueryStringTypes", Action.Request),
            FailingTest(RestJson, "RestJsonQueryStringEscaping", Action.Request),
            FailingTest(RestJson, "RestJsonQueryStringEscaping", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "DocumentOutput", Action.Response),
            FailingTest(RestJson, "DocumentOutput", Action.Response),
@@ -623,6 +622,49 @@ class ServerProtocolTestGenerator(
                    """.trimMargin()
                    """.trimMargin()
                ).asObjectNode().get()
                ).asObjectNode().get()
            ).build()
            ).build()
        private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
             testCase.toBuilder().params(
                 Node.parse("""{
                    "queryString": "Hello there",
                    "queryStringList": ["a", "b", "c"],
                    "queryStringSet": ["a", "b", "c"],
                    "queryByte": 1,
                    "queryShort": 2,
                    "queryInteger": 3,
                    "queryIntegerList": [1, 2, 3],
                    "queryIntegerSet": [1, 2, 3],
                    "queryLong": 4,
                    "queryFloat": 1.1,
                    "queryDouble": 1.1,
                    "queryDoubleList": [1.1, 2.1, 3.1],
                    "queryBoolean": true,
                    "queryBooleanList": [true, false, true],
                    "queryTimestamp": 1,
                    "queryTimestampList": [1, 2, 3],
                    "queryEnum": "Foo",
                    "queryEnumList": ["Foo", "Baz", "Bar"],
                    "queryParamsMapOfStringList": {
                        "String": ["Hello there"],
                        "StringList": ["a", "b", "c"],
                        "StringSet": ["a", "b", "c"],
                        "Byte": ["1"],
                        "Short": ["2"],
                        "Integer": ["3"],
                        "IntegerList": ["1", "2", "3"],
                        "IntegerSet": ["1", "2", "3"],
                        "Long": ["4"],
                        "Float": ["1.1"],
                        "Double": ["1.1"],
                        "DoubleList": ["1.1", "2.1", "3.1"],
                        "Boolean": ["true"],
                        "BooleanList": ["true", "false", "true"],
                        "Timestamp": ["1970-01-01T00:00:01Z"],
                        "TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
                        "Enum": ["Foo"],
                        "EnumList": ["Foo", "Baz", "Bar"]
                    }
                }""".trimMargin()).asObjectNode().get()
             ).build()


        // These are tests whose definitions in the `awslabs/smithy` repository are wrong.
        // These are tests whose definitions in the `awslabs/smithy` repository are wrong.
        // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
        // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
@@ -631,7 +673,8 @@ class ServerProtocolTestGenerator(
            // https://github.com/awslabs/smithy/pull/1040
            // https://github.com/awslabs/smithy/pull/1040
            Pair(RestJson, "RestJsonSupportsNaNFloatQueryValues") to ::fixRestJsonSupportsNaNFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsNaNFloatQueryValues") to ::fixRestJsonSupportsNaNFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsInfinityFloatQueryValues") to ::fixRestJsonSupportsInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsInfinityFloatQueryValues") to ::fixRestJsonSupportsInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues
            Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonAllQueryStringTypes") to ::fixRestJsonAllQueryStringTypes
        )
        )
    }
    }
}
}
+8 −6
Original line number Original line Diff line number Diff line
@@ -32,6 +32,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
@@ -106,6 +107,7 @@ private class ServerHttpProtocolImplGenerator(
    private val codegenScope = arrayOf(
    private val codegenScope = arrayOf(
        "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
        "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
        "AxumCore" to ServerCargoDependency.AxumCore.asType(),
        "AxumCore" to ServerCargoDependency.AxumCore.asType(),
        "Cow" to ServerRuntimeType.Cow,
        "DateTime" to RuntimeType.DateTime(runtimeConfig),
        "DateTime" to RuntimeType.DateTime(runtimeConfig),
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
@@ -116,7 +118,7 @@ private class ServerHttpProtocolImplGenerator(
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "http" to RuntimeType.http,
        "http" to RuntimeType.http
    )
    )


    override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
    override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
@@ -749,7 +751,7 @@ private class ServerHttpProtocolImplGenerator(
            rustTemplate(
            rustTemplate(
                """
                """
                let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?;
                let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?;
                let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(&str, &str)>>(query_string)?;
                let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>(query_string)?;
                """.trimIndent(),
                """.trimIndent(),
                *codegenScope
                *codegenScope
            )
            )
@@ -778,7 +780,7 @@ private class ServerHttpProtocolImplGenerator(
                        """
                        """
                        if !seen_${memberName} && k == "${it.locationName}" {
                        if !seen_${memberName} && k == "${it.locationName}" {
                            input = input.${it.member.setterName()}(
                            input = input.${it.member.setterName()}(
                                #{deserializer}(v)?
                                #{deserializer}(&v)?
                            );
                            );
                            seen_${memberName} = true;
                            seen_${memberName} = true;
                        }
                        }
@@ -798,7 +800,7 @@ private class ServerHttpProtocolImplGenerator(
                                //     * `String` in case it doesn't.
                                //     * `String` in case it doesn't.
                                rustTemplate(
                                rustTemplate(
                                    """
                                    """
                                    let v = <_>::from(#{PercentEncoding}::percent_decode_str(v).decode_utf8()?.as_ref());
                                    let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
                                    """.trimIndent(),
                                    """.trimIndent(),
                                    *codegenScope
                                    *codegenScope
                                )
                                )
@@ -814,7 +816,7 @@ private class ServerHttpProtocolImplGenerator(
                                val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                                val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                                rustTemplate(
                                rustTemplate(
                                    """
                                    """
                                    let v = #{PercentEncoding}::percent_decode_str(v).decode_utf8()?;
                                    let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
                                    let v = #{DateTime}::from_str(&v, #{format})?;
                                    let v = #{DateTime}::from_str(&v, #{format})?;
                                    """.trimIndent(),
                                    """.trimIndent(),
                                    *codegenScope,
                                    *codegenScope,
@@ -824,7 +826,7 @@ private class ServerHttpProtocolImplGenerator(
                            else -> { // Number or boolean.
                            else -> { // Number or boolean.
                                rust(
                                rust(
                                    """
                                    """
                                    let v = <_ as #T>::parse_smithy_primitive(v)?;
                                    let v = <_ as #T>::parse_smithy_primitive(&v)?;
                                    """.trimIndent(),
                                    """.trimIndent(),
                                    CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
                                    CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
                                )
                                )