Unverified Commit 4ccdadde authored by david-perez's avatar david-perez Committed by GitHub
Browse files

`rust-server-codegen`: fix deserialization of escaped query string data (#1058)

We are currently deserializing the query string into `Vec<(&str,
&str)>`. `serde_urlencoded` panics if the input string slice contains
escaped data, since in that case it needs to allocate a new `String` to
unescape the input string slice's contents.

Instead of deserializing to `Vec<(String, String)>`, we can instead use
`Cow<'a, str>` so that deserialization only allocates when strictly
required.

Reference: https://github.com/serde-rs/serde/issues/1413#issuecomment-494892266
parent 7b061f40
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ object ServerRuntimeType {
        RuntimeType(inlineDependency.name, inlineDependency, namespace = "crate")

    val Phantom = RuntimeType("PhantomData", dependency = null, namespace = "std::marker")
    val Cow = RuntimeType("Cow", dependency = null, namespace = "std::borrow")

    fun Router(runtimeConfig: RuntimeConfig) =
        RuntimeType("Router", CargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing")
+45 −2
Original line number Diff line number Diff line
@@ -441,7 +441,6 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonOutputUnionWithUnitMember", Action.Response),
            FailingTest(RestJson, "RestJsonUnitInputAllowsAccept", Action.Request),
            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonAllQueryStringTypes", Action.Request),
            FailingTest(RestJson, "RestJsonQueryStringEscaping", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "DocumentOutput", Action.Response),
@@ -623,6 +622,49 @@ class ServerProtocolTestGenerator(
                    """.trimMargin()
                ).asObjectNode().get()
            ).build()
        private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
             testCase.toBuilder().params(
                 Node.parse("""{
                    "queryString": "Hello there",
                    "queryStringList": ["a", "b", "c"],
                    "queryStringSet": ["a", "b", "c"],
                    "queryByte": 1,
                    "queryShort": 2,
                    "queryInteger": 3,
                    "queryIntegerList": [1, 2, 3],
                    "queryIntegerSet": [1, 2, 3],
                    "queryLong": 4,
                    "queryFloat": 1.1,
                    "queryDouble": 1.1,
                    "queryDoubleList": [1.1, 2.1, 3.1],
                    "queryBoolean": true,
                    "queryBooleanList": [true, false, true],
                    "queryTimestamp": 1,
                    "queryTimestampList": [1, 2, 3],
                    "queryEnum": "Foo",
                    "queryEnumList": ["Foo", "Baz", "Bar"],
                    "queryParamsMapOfStringList": {
                        "String": ["Hello there"],
                        "StringList": ["a", "b", "c"],
                        "StringSet": ["a", "b", "c"],
                        "Byte": ["1"],
                        "Short": ["2"],
                        "Integer": ["3"],
                        "IntegerList": ["1", "2", "3"],
                        "IntegerSet": ["1", "2", "3"],
                        "Long": ["4"],
                        "Float": ["1.1"],
                        "Double": ["1.1"],
                        "DoubleList": ["1.1", "2.1", "3.1"],
                        "Boolean": ["true"],
                        "BooleanList": ["true", "false", "true"],
                        "Timestamp": ["1970-01-01T00:00:01Z"],
                        "TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
                        "Enum": ["Foo"],
                        "EnumList": ["Foo", "Baz", "Bar"]
                    }
                }""".trimMargin()).asObjectNode().get()
             ).build()

        // These are tests whose definitions in the `awslabs/smithy` repository are wrong.
        // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
@@ -631,7 +673,8 @@ class ServerProtocolTestGenerator(
            // https://github.com/awslabs/smithy/pull/1040
            Pair(RestJson, "RestJsonSupportsNaNFloatQueryValues") to ::fixRestJsonSupportsNaNFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsInfinityFloatQueryValues") to ::fixRestJsonSupportsInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues
            Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonAllQueryStringTypes") to ::fixRestJsonAllQueryStringTypes
        )
    }
}
+8 −6
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.writable
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
@@ -106,6 +107,7 @@ private class ServerHttpProtocolImplGenerator(
    private val codegenScope = arrayOf(
        "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
        "AxumCore" to ServerCargoDependency.AxumCore.asType(),
        "Cow" to ServerRuntimeType.Cow,
        "DateTime" to RuntimeType.DateTime(runtimeConfig),
        "HttpBody" to CargoDependency.HttpBody.asType(),
        "Hyper" to CargoDependency.Hyper.asType(),
@@ -116,7 +118,7 @@ private class ServerHttpProtocolImplGenerator(
        "SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
        "SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
        "http" to RuntimeType.http,
        "http" to RuntimeType.http
    )

    override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
@@ -749,7 +751,7 @@ private class ServerHttpProtocolImplGenerator(
            rustTemplate(
                """
                let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?;
                let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(&str, &str)>>(query_string)?;
                let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>(query_string)?;
                """.trimIndent(),
                *codegenScope
            )
@@ -778,7 +780,7 @@ private class ServerHttpProtocolImplGenerator(
                        """
                        if !seen_${memberName} && k == "${it.locationName}" {
                            input = input.${it.member.setterName()}(
                                #{deserializer}(v)?
                                #{deserializer}(&v)?
                            );
                            seen_${memberName} = true;
                        }
@@ -798,7 +800,7 @@ private class ServerHttpProtocolImplGenerator(
                                //     * `String` in case it doesn't.
                                rustTemplate(
                                    """
                                    let v = <_>::from(#{PercentEncoding}::percent_decode_str(v).decode_utf8()?.as_ref());
                                    let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
                                    """.trimIndent(),
                                    *codegenScope
                                )
@@ -814,7 +816,7 @@ private class ServerHttpProtocolImplGenerator(
                                val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                                rustTemplate(
                                    """
                                    let v = #{PercentEncoding}::percent_decode_str(v).decode_utf8()?;
                                    let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
                                    let v = #{DateTime}::from_str(&v, #{format})?;
                                    """.trimIndent(),
                                    *codegenScope,
@@ -824,7 +826,7 @@ private class ServerHttpProtocolImplGenerator(
                            else -> { // Number or boolean.
                                rust(
                                    """
                                    let v = <_ as #T>::parse_smithy_primitive(v)?;
                                    let v = <_ as #T>::parse_smithy_primitive(&v)?;
                                    """.trimIndent(),
                                    CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
                                )