Unverified Commit 8fe1083f authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Deserialize empty collections from the query string into `None` (#1050)

We were instead providing empty collections e.g. `Some(Vec::new())`,
which makes protocol tests fail.

This commit also passes Smithy member names through the symbol provider
to obtain safe variable names.

This commit also adds some plumbing to the server protocol test case
generation to fix broken request test definitions in the
`awslabs/smithy` project.
parent 15eb1c08
Loading
Loading
Loading
Loading
+70 −4
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
@@ -95,7 +96,7 @@ class ServerProtocolTestGenerator(
            val testCases = error.getTrait<HttpResponseTestsTrait>()?.testCases.orEmpty()
            testCases.map { TestCase.ResponseTest(it, error) }
        }
        val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching()
        val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching().fixBroken()

        if (allTests.isNotEmpty()) {
            val operationName = operationSymbol.name
@@ -138,6 +139,24 @@ class ServerProtocolTestGenerator(
        }
    }

    // This function applies a "fix function" to each broken test before we synthesize it.
    // Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong, usually because they have
    // not been written with a server-side perspective in mind.
    private fun List<TestCase>.fixBroken(): List<TestCase> = this.map { when (it) {
        is TestCase.RequestTest -> {
            val howToFixIt = BrokenRequestTests[Pair(codegenContext.serviceShape.id.toString(), it.testCase.id)]
            if (howToFixIt == null) {
                it
            } else {
                val fixed = howToFixIt(it.testCase)
                TestCase.RequestTest(fixed, it.targetShape)
            }
        }
        is TestCase.ResponseTest -> {
            it
        }
    } }

    private fun renderTestCaseBlock(
        testCase: HttpMessageTestCase,
        testModuleWriter: RustWriter,
@@ -413,11 +432,8 @@ class ServerProtocolTestGenerator(
        private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
        private val ExpectFail = setOf<FailingTest>(
            FailingTest(RestJson, "RestJsonAllQueryStringTypes", Action.Request),
            FailingTest(RestJson, "RestJsonQueryStringMap", Action.Request),
            FailingTest(RestJson, "RestJsonQueryStringEscaping", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsInfinityFloatQueryValues", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues", Action.Request),
            FailingTest(RestJson, "DocumentOutput", Action.Response),
            FailingTest(RestJson, "DocumentOutputString", Action.Response),
            FailingTest(RestJson, "DocumentOutputNumber", Action.Response),
@@ -549,5 +565,55 @@ class ServerProtocolTestGenerator(
        // These tests are not even attempted to be generated, either because they will not compile
        // or because they are flaky
        private val DisableTests = setOf<String>()

        private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase {
            // TODO This test does not pass, even after fixing it with this function, because, in IEEE 754 floating
            // point numbers, `NaN` is not equal to any other floating point number, even itself! So we can't compare it
            // to any "expected" value.
            // Reference: https://doc.rust-lang.org/std/primitive.f32.html
            // Request for guidance about this test to Smithy team: https://github.com/awslabs/smithy/pull/1040#discussion_r780418707
            val params = Node.parse("""{
                "queryFloat": "NaN",
                "queryDouble": "NaN",
                "queryParamsMapOfStringList": {
                    "Float": ["NaN"],
                    "Double": ["NaN"]
                }
            }""".trimIndent()).asObjectNode().get()

            return testCase.toBuilder().params(params).build()
        }
        private fun fixRestJsonSupportsInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase =
            testCase.toBuilder().params(
               Node.parse("""{
                   "queryFloat": "Infinity",
                   "queryDouble": "Infinity",
                   "queryParamsMapOfStringList": {
                       "Float": ["Infinity"],
                       "Double": ["Infinity"]
                   }
               }""".trimMargin()).asObjectNode().get()
            ).build()
        private fun fixRestJsonSupportsNegativeInfinityFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase =
            testCase.toBuilder().params(
                Node.parse("""{
                   "queryFloat": "-Infinity",
                   "queryDouble": "-Infinity",
                   "queryParamsMapOfStringList": {
                       "Float": ["-Infinity"],
                       "Double": ["-Infinity"]
                   }
               }""".trimMargin()).asObjectNode().get()
            ).build()

        // These are tests whose definitions in the `awslabs/smithy` repository are wrong.
        // This is because they have not been written from a server perspective, and as such the expected `params` field is incomplete.
        // TODO Contribute a PR to fix them upstream and remove them from this list once the fixes get published in the next Smithy release.
        private val BrokenRequestTests = mapOf(
            // https://github.com/awslabs/smithy/pull/1040
            Pair(RestJson, "RestJsonSupportsNaNFloatQueryValues") to ::fixRestJsonSupportsNaNFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsInfinityFloatQueryValues") to ::fixRestJsonSupportsInfinityFloatQueryValues,
            Pair(RestJson, "RestJsonSupportsNegativeInfinityFloatQueryValues") to ::fixRestJsonSupportsNegativeInfinityFloatQueryValues
        )
    }
}
+14 −6
Original line number Diff line number Diff line
@@ -722,22 +722,23 @@ private class ServerHttpProtocolImplGenerator(
            val (queryBindingsTargettingCollection, queryBindingsTargettingSimple) =
                queryBindings.partition { model.expectShape(it.member.target) is CollectionShape }
            queryBindingsTargettingSimple.forEach {
                rust("let mut seen_${it.memberName.toSnakeCase()} = false;")
                rust("let mut seen_${symbolProvider.toMemberName(it.member)} = false;")
            }
            queryBindingsTargettingCollection.forEach {
                rust("let mut ${it.memberName.toSnakeCase()} = Vec::new();")
                rust("let mut ${symbolProvider.toMemberName(it.member)} = Vec::new();")
            }

            rustBlock("for (k, v) in pairs") {
                queryBindingsTargettingSimple.forEach {
                    val deserializer = generateParsePercentEncodedStrFn(it)
                    val memberName = symbolProvider.toMemberName(it.member)
                    rustTemplate(
                        """
                        if !seen_${it.memberName.toSnakeCase()} && k == "${it.locationName}" {
                        if !seen_${memberName} && k == "${it.locationName}" {
                            input = input.${it.member.setterName()}(
                                #{deserializer}(v)?
                            );
                            seen_${it.memberName.toSnakeCase()} = true;
                            seen_${memberName} = true;
                        }
                        """.trimIndent(),
                        "deserializer" to deserializer
@@ -787,7 +788,7 @@ private class ServerHttpProtocolImplGenerator(
                                )
                            }
                        }
                        rust("${it.memberName.toSnakeCase()}.push(v);")
                        rust("${symbolProvider.toMemberName(it.member)}.push(v);")
                    }
                }

@@ -810,9 +811,16 @@ private class ServerHttpProtocolImplGenerator(
                rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));")
            }
            queryBindingsTargettingCollection.forEach {
                val memberName = symbolProvider.toMemberName(it.member)
                rustTemplate(
                    """
                    input = input.${it.member.setterName()}(Some(${it.memberName.toSnakeCase()}));
                    input = input.${it.member.setterName()}(
                        if ${memberName}.is_empty() {
                            None
                        } else {
                            Some(${memberName})
                        }
                    );
                    """.trimIndent()
                )
            }