Unverified Commit 97a49f3c authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Consider `NaN` to be equal to itself in server protocol tests (#1177)

The client protocol tests use the
`aws_smithy_protocol_test::FloatEquals` for this [0].

Note we're only applying this to direct floating point shape members,
i.e. this commit _does not_ address #1147.

[0]: https://docs.rs/aws-smithy-protocol-test/latest/aws_smithy_protocol_test/trait.FloatEquals.html
parent e1099324
Loading
Loading
Loading
Loading
+36 −10
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
@@ -356,10 +358,43 @@ class ServerProtocolTestGenerator(
                    )
                }
            }
        } else {
            val hasFloatingPointMembers = inputShape.members().any {
                val target = model.expectShape(it.target)
                (target is DoubleShape) || (target is FloatShape)
            }

            // TODO(https://github.com/awslabs/smithy-rs/issues/1147) Handle the case of nested floating point members.
            if (hasFloatingPointMembers) {
                for (member in inputShape.members()) {
                    val memberName = codegenContext.symbolProvider.toMemberName(member)
                    when (codegenContext.model.expectShape(member.target)) {
                        is DoubleShape, is FloatShape -> {
                            rustWriter.addUseImports(
                                RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "FloatEquals").toSymbol()
                            )
                            rustWriter.rust(
                                """
                                assert!(parsed.$memberName.float_equals(&expected.$memberName),
                                    "Unexpected value for `$memberName` {:?} vs. {:?}", expected.$memberName, parsed.$memberName);
                                """
                            )
                        }
                        else -> {
                            rustWriter.rustTemplate(
                                """
                                    #{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
                                    """,
                                *codegenScope
                            )
                        }
                    }
                }
            } else {
                rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
            }
        }
    }

    private fun checkBody(rustWriter: RustWriter, body: String, mediaType: String?) {
        rustWriter.rustTemplate(
@@ -503,11 +538,9 @@ class ServerProtocolTestGenerator(
        private val ExpectFail = setOf<FailingTest>(
            // Headers.
            FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", Action.Request),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatHeaderInputs", Action.Request),
            FailingTest(RestJson, "RestJsonInputAndOutputWithQuotedStringHeaders", Action.Response),

            FailingTest(RestJson, "RestJsonUnitInputAndOutputNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatQueryValues", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestJson, "RestJsonFooErrorUsingCode", Action.Response),
@@ -516,10 +549,8 @@ class ServerProtocolTestGenerator(
            FailingTest(RestJson, "RestJsonFooErrorWithDunderType", Action.Response),
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeAndNamespace", Action.Response),
            FailingTest(RestJson, "RestJsonFooErrorWithDunderTypeUriAndNamespace", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatLabels", Action.Request),
            FailingTest(RestJson, "RestJsonHttpResponseCode", Action.Response),
            FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
            FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
            FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
            FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
            FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),
@@ -540,11 +571,6 @@ class ServerProtocolTestGenerator(
        private val DisableTests = setOf<String>()

        private fun fixRestJsonSupportsNaNFloatQueryValues(testCase: HttpRequestTestCase): HttpRequestTestCase {
            // TODO This test does not pass, even after fixing it with this function, because, in IEEE 754 floating
            // point numbers, `NaN` is not equal to any other floating point number, even itself! So we can't compare it
            // to any "expected" value.
            // Reference: https://doc.rust-lang.org/std/primitive.f32.html
            // Request for guidance about this test to Smithy team: https://github.com/awslabs/smithy/pull/1040#discussion_r780418707
            val params = Node.parse(
                """
                {