Unverified Commit 1d00f381 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Protocol Test Updates (#336)

* Update procotol test generator to use the external parsing API
parent 81a1ab6e
Loading
Loading
Loading
Loading
+22 −3
Original line number Diff line number Diff line
@@ -18,11 +18,14 @@ import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.escape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
@@ -225,18 +228,34 @@ class HttpProtocolTestGenerator(
        writeInline("let expected_output =")
        instantiator.render(this, expectedShape, testCase.params)
        write(";")
        write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder)
        write("let mut http_response = #T::new()", RuntimeType.HttpResponseBuilder)
        testCase.headers.forEach { (key, value) ->
            writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})")
        }
        rust(
            """
                .status(${testCase.code})
                .body(${testCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"})
                .body(#T::from(${testCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"}))
                .unwrap();
            """,
            RuntimeType.sdkBody(runtimeConfig = protocolConfig.runtimeConfig)
        )
        rustTemplate(
            """
            use #{parse_http_response};
            let parser = #{op}::new();
            let parsed = parser.parse_unloaded(&mut http_response);
            let http_response = http_response.map(|body|#{bytes}::copy_from_slice(body.bytes().unwrap()));
            let parsed = parsed.unwrap_or_else(||
                <#{op} as #{parse_http_response}<#{sdk_body}>>::parse_loaded(&parser, &http_response)
            );
        """,
            "op" to operationSymbol,
            "bytes" to RuntimeType.Bytes,
            "parse_http_response" to CargoDependency.SmithyHttp(protocolConfig.runtimeConfig).asType()
                .member("response::ParseHttpResponse"),
            "sdk_body" to RuntimeType.sdkBody(runtimeConfig = protocolConfig.runtimeConfig)
        )
        write("let parsed = #T::from_response(&http_response);", operationSymbol)
        if (expectedShape.hasTrait(ErrorTrait::class.java)) {
            val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider)
            val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name
+19 −2
Original line number Diff line number Diff line
@@ -14,10 +14,12 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.CodegenVisitor
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
@@ -25,6 +27,7 @@ import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.generatePluginContext
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.runCommand
import java.nio.file.Path

@@ -117,9 +120,23 @@ class HttpProtocolTestGeneratorTest {
    ): Path {

        // A stubbed test protocol to do enable testing intentionally broken protocols
        class TestProtocol(protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) {
        class TestProtocol(private val protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) {
            private val symbolProvider = protocolConfig.symbolProvider
            override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) {
                // no trait implementations for tests
                operationWriter.rustTemplate(
                    """
                    impl #{parse_strict} for ${operationShape.id.name}{
                        type Output = Result<#{output}, #{error}>;
                        fn parse(&self, response: &#{response}<#{bytes}>) -> Self::Output {
                            self.parse_response(response)
                        }
                    }""",
                    "parse_strict" to RuntimeType.parseStrict(protocolConfig.runtimeConfig),
                    "output" to symbolProvider.toSymbol(operationShape.outputShape(protocolConfig.model)),
                    "error" to operationShape.errorSymbol(symbolProvider),
                    "response" to RuntimeType.Http("Response"),
                    "bytes" to RuntimeType.Bytes
                )
            }

            override fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) {