Unverified Commit 5eb3c822 authored by Julian Antonielli's avatar Julian Antonielli Committed by GitHub
Browse files

Run HttpMalformedRequest tests through router (#1904)



* Add links to relevant issues/PRs for failing tests

* Update codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt

Co-authored-by: default avatardavid-perez <d@vidp.dev>

* Make `build_router_and_make_request` return the response

* Use `rustWriter.write` over `rustWriter.rust` for incomplete rust code

* Run HttpMalformedRequest tests through router

* Don't assert on operation_extension

This causes HttpMalformedRequest tests to fail.

* Refactor code to avoid repetition

* Run ktlint

* Check operation extension for regular request tests

* Run ktlint

* Remove unused parameter from `makeRequest`

* Refactor checkRequest2

* Remove `checkRequest/2`, run new-style tests for malformed http requests

- Remove `checkRequest`, `checkRequest2` in favour
  of using `makeRequest`/`makeRequest2` directly
- Run service builder tests for MalformedHttpRequest tests

* Format comment better

* Remove `RestJsonStreamingTraitsRequireLengthWithBlob` from failing tests

This test actually passes fine, but is not relevant for the server.

* Add comment about `RestJsonHttpResponseCodeDefaultsToModeledCode`

* Remove channel assertion on `makeRequest2`

* Use `checkOperationExtension` for `makeRequest2`

* Attempt to use a channel for both `makeRequest/2`

* Go back to using channel for new api and operation extension for old api

* Fix malformedHttpRequest tests

* Add block around each API's tests

* Remove unrelated file

* Mark wrong tests as failing

Marks these two tests as failing:
  - RestJsonWithPayloadExpectsImpliedContentType
  - RestJsonBodyMalformedMapNullKey
These will be fixed in https://github.com/awslabs/smithy/pull/1477

.

Co-authored-by: default avatardavid-perez <d@vidp.dev>
parent 4c852b1d
Loading
Loading
Loading
Loading
+69 −32
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
@@ -200,12 +201,10 @@ class ServerProtocolTestGenerator(
                        #{RegistryBuilderMethods:W}
                }

                /// The operation full name is a concatenation of `<operation namespace>.<operation name>`.
                pub(crate) async fn build_router_and_make_request(
                    http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>,
                    operation_full_name: &str,
                    f: &dyn Fn(RegistryBuilder) -> RegistryBuilder,
                ) {
                ) -> #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody> {
                    let mut router: #{Router} = f(create_operation_registry_builder())
                        .build()
                        .expect("unable to build operation registry")
@@ -214,6 +213,12 @@ class ServerProtocolTestGenerator(
                        .call(http_request)
                        .await
                        .expect("unable to make an HTTP request");

                    http_response
                }

                /// The operation full name is a concatenation of `<operation namespace>.<operation name>`.
                pub(crate) fn check_operation_extension_was_set(http_response: #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody>, operation_full_name: &str) {
                    let operation_extension = http_response.extensions()
                        .get::<#{SmithyHttpServer}::extension::OperationExtension>()
                        .expect("extension `OperationExtension` not found");
@@ -284,6 +289,7 @@ class ServerProtocolTestGenerator(

                    is TestCase.MalformedRequestTest -> this.renderHttpMalformedRequestTestCase(
                        it.testCase,
                        operationShape,
                        operationSymbol,
                    )
                }
@@ -388,7 +394,8 @@ class ServerProtocolTestGenerator(
            renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
        }
        if (protocolSupport.requestBodyDeserialization) {
            checkRequest(operationShape, operationSymbol, httpRequestTestCase, this)
            makeRequest(operationShape, this, checkRequestHandler(operationShape, httpRequestTestCase))
            checkHandlerWasEntered(operationShape, operationSymbol, this)
        }

        // Test against new service builder.
@@ -396,7 +403,8 @@ class ServerProtocolTestGenerator(
            renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
        }
        if (protocolSupport.requestBodyDeserialization) {
            checkRequest2(operationShape, operationSymbol, httpRequestTestCase, this)
            makeRequest2(operationShape, operationSymbol, this, checkRequestHandler(operationShape, httpRequestTestCase))
            checkHandlerWasEntered2(this)
        }

        // Explicitly warn if the test case defined parameters that we aren't doing anything with
@@ -467,25 +475,31 @@ class ServerProtocolTestGenerator(
     */
    private fun RustWriter.renderHttpMalformedRequestTestCase(
        testCase: HttpMalformedRequestTestCase,
        operationShape: OperationShape,
        operationSymbol: Symbol,
    ) {
        val (_, outputT) = operationInputOutputTypes[operationShape]!!

        rust("// Use the `OperationRegistryBuilder`")
        rustBlock("") {
            with(testCase.request) {
                // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
                renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
            }
            makeRequest(operationShape, this, writable("todo!() as $outputT"))
            checkResponse(this, testCase.response)
        }

        val operationName = "${operationSymbol.name}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
        rustTemplate(
            """
            let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request);
            let rejection = super::$operationName::from_request(&mut http_request).await.expect_err("request was accepted but we expected it to be rejected");
            let http_response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(rejection);
            """,
            "Protocol" to protocolGenerator.protocol.markerStruct(),
            *codegenScope,
        )
        rust("// Use new service builder")
        rustBlock("") {
            with(testCase.request) {
                // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
                renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
            }
            makeRequest2(operationShape, operationSymbol, this, writable("todo!() as $outputT"))
            checkResponse(this, testCase.response)
        }
    }

    private fun RustWriter.renderHttpRequest(
        uri: String,
@@ -563,41 +577,53 @@ class ServerProtocolTestGenerator(
        }

    /** Checks the request using the `OperationRegistryBuilder`. */
    private fun checkRequest(
    private fun makeRequest(
        operationShape: OperationShape,
        operationSymbol: Symbol,
        httpRequestTestCase: HttpRequestTestCase,
        rustWriter: RustWriter,
        operationBody: Writable,
    ) {
        val (inputT, outputT) = operationInputOutputTypes[operationShape]!!

        rustWriter.withBlock(
        rustWriter.withBlockTemplate(
            """
            super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request(
            let http_response = super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request(
                http_request,
                "${operationShape.id.namespace}.${operationSymbol.name}",
                &|builder| {
                    builder.${operationShape.toName()}((|input| Box::pin(async move {
            """,

            "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;",

            *codegenScope,
        ) {
            checkRequestHandler(operationShape, httpRequestTestCase)()
            operationBody()
        }
    }

    private fun checkHandlerWasEntered(
        operationShape: OperationShape,
        operationSymbol: Symbol,
        rustWriter: RustWriter,
    ) {
        val operationFullName = "${operationShape.id.namespace}.${operationSymbol.name}"
        rustWriter.rust(
            """
            super::$PROTOCOL_TEST_HELPER_MODULE_NAME::check_operation_extension_was_set(http_response, "$operationFullName");
            """,
        )
    }

    /** Checks the request using the new service builder. */
    private fun checkRequest2(
    private fun makeRequest2(
        operationShape: OperationShape,
        operationSymbol: Symbol,
        httpRequestTestCase: HttpRequestTestCase,
        rustWriter: RustWriter,
        body: Writable,
    ) {
        val (inputT, _) = operationInputOutputTypes[operationShape]!!
        val operationName = RustReservedWords.escapeIfNeeded(operationSymbol.name.toSnakeCase())
        rustWriter.rustTemplate(
            """
            ##[allow(unused_mut)]
            let (sender, mut receiver) = #{Tokio}::sync::mpsc::channel(1);
            let service = crate::service::$serviceName::unchecked_builder()
                .$operationName(move |input: $inputT| {
@@ -612,13 +638,20 @@ class ServerProtocolTestGenerator(
            let http_response = #{Tower}::ServiceExt::oneshot(service, http_request)
                .await
                .expect("unable to make an HTTP request");
            assert!(receiver.recv().await.is_some())
            """,
            "Body" to checkRequestHandler(operationShape, httpRequestTestCase),
            "Body" to body,
            *codegenScope,
        )
    }

    private fun checkHandlerWasEntered2(rustWriter: RustWriter) {
        rustWriter.rust(
            """
            assert!(receiver.recv().await.is_some());
            """,
        )
    }

    private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) {
        if (inputShape.hasStreamingMember(model)) {
            // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
@@ -842,7 +875,7 @@ class ServerProtocolTestGenerator(
    private fun assertOk(rustWriter: RustWriter, inner: Writable) {
        rustWriter.rust("#T(", RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "assert_ok"))
        inner(rustWriter)
        rustWriter.rust(");")
        rustWriter.write(");")
    }

    private fun strSlice(writer: RustWriter, args: List<String>) {
@@ -872,6 +905,10 @@ class ServerProtocolTestGenerator(
        private val AwsQuery = "aws.protocoltests.query#AwsQuery"
        private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
        private val ExpectFail = setOf<FailingTest>(
            // Pending merge from the Smithy team: see https://github.com/awslabs/smithy/pull/1477.
            FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest),
            FailingTest(RestJson, "RestJsonBodyMalformedMapNullKey", TestType.MalformedRequest),

            // Pending resolution from the Smithy team, see https://github.com/awslabs/smithy/issues/1068.
            FailingTest(RestJson, "RestJsonHttpWithHeadersButNoPayload", TestType.Request),