Unverified Commit 171d1a4a authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Use `pretty_assertions` in protocol tests (#1038)

* Use `pretty_assertions` in protocol tests

We do this by pulling in `pretty_assertions` into the `codegenScope` so
that rendered code that uses `assert_eq!` _explicitly_ invokes the macro
from `pretty_assertions`.
parent 38fc65da
Loading
Loading
Loading
Loading
+12 −9
Original line number Diff line number Diff line
@@ -75,6 +75,7 @@ class ServerProtocolTestGenerator(
        "Hyper" to CargoDependency.Hyper.asType(),
        "AxumCore" to ServerCargoDependency.AxumCore.asType(),
        "SmithyHttpServer" to CargoDependency.SmithyHttpServer(codegenContext.runtimeConfig).asType(),
        "AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!")
    )

    sealed class TestCase {
@@ -287,20 +288,21 @@ class ServerProtocolTestGenerator(
            """,
            *codegenScope,
        )
        rust(
        rustTemplate(
            """
            assert_eq!(
            #{AssertEq}(
                http::StatusCode::from_u16(${testCase.code}).expect("invalid expected HTTP status code"),
                http_response.status()
            );
            """
            """,
            *codegenScope
        )
        checkHttpExtensions(this)
        if (!testCase.body.isEmpty()) {
            rustTemplate(
                """
                let body = #{Hyper}::body::to_bytes(http_response.into_body()).await.expect("unable to extract body to bytes");
                assert_eq!(${escape(testCase.body.get()).dq()}, body);
                #{AssertEq}(${escape(testCase.body.get()).dq()}, body);
                """,
                *codegenScope
            )
@@ -329,17 +331,18 @@ class ServerProtocolTestGenerator(
        if (operationShape.outputShape(model).hasStreamingMember(model)) {
            rustWriter.rust("""todo!("streaming types aren't supported yet");""")
        } else {
            rustWriter.rust("assert_eq!(input, expected);")
            rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope)
        }
    }

    private fun checkHttpExtensions(rustWriter: RustWriter) {
        rustWriter.rust(
        rustWriter.rustTemplate(
            """
            let request_extensions = http_response.extensions().get::<aws_smithy_http_server::RequestExtensions>().expect("extension `RequestExtensions` not found");
            assert_eq!(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
            assert_eq!(request_extensions.operation_name, ${operationSymbol.name.dq()});
            """.trimIndent()
            #{AssertEq}(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
            #{AssertEq}(request_extensions.operation_name, ${operationSymbol.name.dq()});
            """.trimIndent(),
            *codegenScope
        )
    }

+1 −0
Original line number Diff line number Diff line
@@ -201,6 +201,7 @@ data class CargoDependency(
        val LazyStatic: CargoDependency = CargoDependency("lazy_static", CratesIo("1.4"))
        val Md5: CargoDependency = CargoDependency("md5", CratesIo("0.7"))
        val PercentEncoding: CargoDependency = CargoDependency("percent-encoding", CratesIo("2"))
        val PrettyAssertions: CargoDependency = CargoDependency("pretty_assertions", CratesIo("1"), scope = DependencyScope.Dev)
        val Regex: CargoDependency = CargoDependency("regex", CratesIo("1"))
        val Ring: CargoDependency = CargoDependency("ring", CratesIo("0.16"))
        val TempFile: CargoDependency = CargoDependency("temp-file", CratesIo("0.1.6"), scope = DependencyScope.Dev)
+28 −15
Original line number Diff line number Diff line
@@ -78,6 +78,12 @@ class ProtocolTestGenerator(
        Instantiator(symbolProvider, model, runtimeConfig)
    }

    private val codegenScope = arrayOf(
        "SmithyHttp" to CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(),
        "Http" to CargoDependency.Http.asType(),
        "AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!")
    )

    sealed class TestCase {
        abstract val testCase: HttpMessageTestCase

@@ -189,23 +195,24 @@ class ProtocolTestGenerator(
        with(httpRequestTestCase) {
            host.orNull()?.also { host ->
                val withScheme = "http://$host"
                rust(
                rustTemplate(
                    """
                    let mut http_request = http_request;
                    let ep = #T::endpoint::Endpoint::mutable(#T::Uri::from_static(${withScheme.dq()}));
                    let ep = #{SmithyHttp}::endpoint::Endpoint::mutable(#{Http}::Uri::from_static(${withScheme.dq()}));
                    ep.set_endpoint(http_request.uri_mut(), parts.acquire().get());
                    """,
                    CargoDependency.SmithyHttp(codegenContext.runtimeConfig).asType(), CargoDependency.Http.asType()
                    *codegenScope
                )
            }
            rust(
                """
                assert_eq!(http_request.method(), ${method.dq()});
                assert_eq!(http_request.uri().path(), ${uri.dq()});
            rustTemplate(
                """
                #{AssertEq}(http_request.method(), ${method.dq()});
                #{AssertEq}(http_request.uri().path(), ${uri.dq()});
                """,
                *codegenScope
            )
            resolvedHost.orNull()?.also { host ->
                rust("""assert_eq!(http_request.uri().host().expect("host should be set"), ${host.dq()});""")
                rustTemplate("""#{AssertEq}(http_request.uri().host().expect("host should be set"), ${host.dq()});""", *codegenScope)
            }
        }
        checkQueryParams(this, httpRequestTestCase.queryParams)
@@ -295,7 +302,7 @@ class ProtocolTestGenerator(
            val errorVariant = codegenContext.symbolProvider.toSymbol(expectedShape).name
            rust("""let parsed = parsed.expect_err("should be error response");""")
            rustBlock("if let #TKind::$errorVariant(actual_error) = parsed.kind", errorSymbol) {
                rust("assert_eq!(expected_output, actual_error);")
                rustTemplate("#{AssertEq}(expected_output, actual_error);", *codegenScope)
            }
            rustBlock("else") {
                rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);")
@@ -305,13 +312,14 @@ class ProtocolTestGenerator(
            outputShape.members().forEach { member ->
                val memberName = codegenContext.symbolProvider.toMemberName(member)
                if (member.isStreaming(codegenContext.model)) {
                    rust(
                    rustTemplate(
                        """
                        assert_eq!(
                        #{AssertEq}(
                            parsed.$memberName.collect().await.unwrap().into_bytes(),
                            expected_output.$memberName.collect().await.unwrap().into_bytes()
                        );
                        """
                        """,
                        *codegenScope
                    )
                } else {
                    when (codegenContext.model.expectShape(member.target)) {
@@ -327,7 +335,7 @@ class ProtocolTestGenerator(
                            )
                        }
                        else ->
                            rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""")
                            rustTemplate("""#{AssertEq}(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""", *codegenScope)
                    }
                }
            }
@@ -345,8 +353,13 @@ class ProtocolTestGenerator(
    private fun checkBody(rustWriter: RustWriter, body: String, mediaType: String?) {
        rustWriter.write("""let body = http_request.body().bytes().expect("body should be strict");""")
        if (body == "") {
            rustWriter.write("// No body")
            rustWriter.write("assert_eq!(std::str::from_utf8(body).unwrap(), ${"".dq()});")
            rustWriter.rustTemplate(
                """
                // No body
                #{AssertEq}(std::str::from_utf8(body).unwrap(), "");
                """,
                *codegenScope
            )
        } else {
            // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
            assertOk(rustWriter) {