Unverified Commit 0af416fc authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add response protocol test support (#73)

* Add response protocol test support

* CR improvements
parent 48b4ef6a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -114,6 +114,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
            RuntimeType(name = path, dependency = CargoDependency.Http, namespace = "http")

        val HttpRequestBuilder = Http("request::Builder")
        val HttpResponseBuilder = Http("response::Builder")

        val Serialize = RuntimeType("Serialize", CargoDependency.Serde, namespace = "serde")
        val Deserialize: RuntimeType = RuntimeType("Deserialize", CargoDependency.Serde, namespace = "serde")
+160 −55
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.lang.Custom
import software.amazon.smithy.rust.codegen.lang.RustMetadata
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.docs
import software.amazon.smithy.rust.codegen.lang.escape
import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
import java.util.logging.Logger

data class ProtocolSupport(
    val requestBodySerialization: Boolean
    val requestBodySerialization: Boolean,
    val responseDeserialization: Boolean,
    val errorDeserialization: Boolean
)

/**
 * Generate protocol tests for an operation
 */
@@ -28,6 +40,7 @@ class HttpProtocolTestGenerator(
    private val writer: RustWriter
) {
    private val logger = Logger.getLogger(javaClass.name)

    // TODO: remove these once Smithy publishes fixes.
    // These tests are not even attempted to be compiled
    val DisableTests = setOf(
@@ -55,15 +68,35 @@ class HttpProtocolTestGenerator(
        "RestJsonHttpPrefixHeadersArePresent" // https://github.com/awslabs/smithy-rs/issues/35
    )
    private val inputShape = operationShape.inputShape(protocolConfig.model)
    fun render() {
        operationShape.getTrait(HttpRequestTestsTrait::class.java).map {
            renderHttpRequestTests(it)
    private val outputShape = operationShape.outputShape(protocolConfig.model)
    private val operationSymbol = protocolConfig.symbolProvider.toSymbol(operationShape)
    private val operationIndex = OperationIndex.of(protocolConfig.model)

    private val instantiator = with(protocolConfig) {
        Instantiator(symbolProvider, model, runtimeConfig)
    }

    sealed class TestCase {
        abstract val testCase: HttpMessageTestCase

        data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase()
        data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) :
            TestCase()
    }

    private fun renderHttpRequestTests(httpRequestTestsTrait: HttpRequestTestsTrait) {
        with(protocolConfig) {
            val operationName = symbolProvider.toSymbol(operationShape).name
    fun render() {
        val requestTests = operationShape.getTrait(HttpRequestTestsTrait::class.java)
            .orNull()?.testCases.orEmpty().map { TestCase.RequestTest(it) }
        val responseTests = operationShape.getTrait(HttpResponseTestsTrait::class.java)
            .orNull()?.testCases.orEmpty().map { TestCase.ResponseTest(it, outputShape) }

        val errorTests = operationIndex.getErrors(operationShape).flatMap { error ->
            val testCases = error.getTrait(HttpResponseTestsTrait::class.java).orNull()?.testCases.orEmpty()
            testCases.map { TestCase.ResponseTest(it, error) }
        }
        val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching()
        if (allTests.isNotEmpty()) {
            val operationName = operationSymbol.name
            val testModuleName = "${operationName.toSnakeCase()}_request_test"
            val moduleMeta = RustMetadata(
                public = false,
@@ -73,33 +106,58 @@ class HttpProtocolTestGenerator(
                )
            )
            writer.withModule(testModuleName, moduleMeta) {
                httpRequestTestsTrait.testCases.filter { it.protocol == protocol }
                    .filter { !DisableTests.contains(it.id) }.forEach { testCase ->
                        try {
                            renderHttpRequestTestCase(testCase, this)
                        } catch (ex: Exception) {
                            println("failed to generate ${testCase.id}")
                            ex.printStackTrace()
                renderAllTestCases(allTests)
            }
        }
    }

    private fun RustWriter.renderAllTestCases(allTests: List<TestCase>) {
        allTests.forEach {
            renderTestCaseBlock(it.testCase, this) {
                when (it) {
                    is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase)
                    is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape)
                }
            }
        }
    }

    private val instantiator = with(protocolConfig) {
        Instantiator(symbolProvider, model, runtimeConfig)
    /**
     * Filter out test cases that are disabled or don't match the service protocol
     */
    private fun List<TestCase>.filterMatching(): List<TestCase> = this.filter { testCase ->
        testCase.testCase.protocol == protocolConfig.protocol &&
            !DisableTests.contains(testCase.testCase.id)
    }

    private fun renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase, testModuleWriter: RustWriter) {
        httpRequestTestCase.documentation.map {
            testModuleWriter.docs(testModuleWriter.escape(it))
    private fun renderTestCaseBlock(
        testCase: HttpMessageTestCase,
        testModuleWriter: RustWriter,
        block: RustWriter.() -> Unit
    ) {
        testModuleWriter.setNewlinePrefix("/// ")
        testCase.documentation.map {
            testModuleWriter.writeWithNoFormatting(it)
        }
        testModuleWriter.docs("Test ID: ${httpRequestTestCase.id}")
        testModuleWriter.write("Test ID: ${testCase.id}")
        testModuleWriter.setNewlinePrefix("")
        testModuleWriter.writeWithNoFormatting("#[test]")
        if (ExpectFail.contains(httpRequestTestCase.id)) {
        if (ExpectFail.contains(testCase.id)) {
            testModuleWriter.writeWithNoFormatting("#[should_panic]")
        }
        testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") {
        val fnName = when (testCase) {
            is HttpResponseTestCase -> "_response"
            is HttpRequestTestCase -> "_request"
            else -> throw CodegenException("unknown test case type")
        }
        testModuleWriter.rustBlock("fn test_${testCase.id.toSnakeCase()}$fnName()") {
            block(this)
        }
    }

    private fun RustWriter.renderHttpRequestTestCase(
        httpRequestTestCase: HttpRequestTestCase
    ) {
        writeInline("let input =")
        instantiator.render(this, inputShape, httpRequestTestCase.params)
        write(";")
@@ -136,6 +194,51 @@ class HttpProtocolTestGenerator(
            }
        }
    }

    private fun RustWriter.renderHttpResponseTestCase(
        httpResponseTestCase: HttpResponseTestCase,
        expectedShape: StructureShape
    ) {
        if (!protocolSupport.responseDeserialization || (
            !protocolSupport.errorDeserialization && expectedShape.hasTrait(
                    ErrorTrait::class.java
                )
            )
        ) {
            rust("/* test case disabled for this protocol (not yet supported) */")
            if (ExpectFail.contains(httpResponseTestCase.id)) {
                // this test needs to fail, minor hack. Caused by overlap between ids of request & response tests
                write("todo!()")
            }
            return
        }
        writeInline("let expected_output =")
        instantiator.render(this, expectedShape, httpResponseTestCase.params)
        write(";")
        write("let http_response = #T::new()", RuntimeType.HttpResponseBuilder)
        httpResponseTestCase.headers.forEach { (key, value) ->
            writeWithNoFormatting(".header(${key.dq()}, ${value.dq()})")
        }
        rust(
            """
                .status(${httpResponseTestCase.code})
                .body(${httpResponseTestCase.body.orNull()?.dq()?.replace("#", "##") ?: "vec![]"})
                .unwrap();
            """
        )
        write("let parsed = #T::from_response(http_response);", operationSymbol)
        if (expectedShape.hasTrait(ErrorTrait::class.java)) {
            val errorSymbol = operationShape.errorSymbol(protocolConfig.symbolProvider)
            val errorVariant = protocolConfig.symbolProvider.toSymbol(expectedShape).name
            rustBlock("if let Err(#T::$errorVariant(actual_error)) = parsed", errorSymbol) {
                write("assert_eq!(expected_output, actual_error);")
            }
            rustBlock("else") {
                write("panic!(\"wrong variant: {:?}\", parsed);")
            }
        } else {
            write("assert_eq!(parsed.unwrap(), expected_output);")
        }
    }

    private fun checkRequiredHeaders(rustWriter: RustWriter, requireHeaders: List<String>) {
@@ -154,7 +257,9 @@ class HttpProtocolTestGenerator(
            // When we generate a body instead of a stub, drop the trailing `;` and enable the assertion
            assertOk(rustWriter) {
                rustWriter.write(
                    "#T(input.build_body(), ${rustWriter.escape(body).dq()}, #T::from(${(mediaType ?: "unknown").dq()}))",
                    "#T(input.build_body(), ${
                    rustWriter.escape(body).dq()
                    }, #T::from(${(mediaType ?: "unknown").dq()}))",
                    RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "validate_body"),
                    RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "MediaType")
                )
+1 −1
Original line number Diff line number Diff line
@@ -73,7 +73,7 @@ class BasicAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGenerat
        )
    }

    override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true)
    override fun support(): ProtocolSupport = ProtocolSupport(requestBodySerialization = true, responseDeserialization = false, errorDeserialization = false)
}

/**
+1 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory<AwsRestJsonGenerator> {

    override fun support(): ProtocolSupport {
        // TODO: Support body for RestJson
        return ProtocolSupport(requestBodySerialization = false)
        return ProtocolSupport(requestBodySerialization = false, responseDeserialization = false, errorDeserialization = false)
    }
}

+5 −0
Original line number Diff line number Diff line
@@ -14,3 +14,8 @@ fun OperationShape.inputShape(model: Model): StructureShape {
    // The Rust Smithy generator adds an input to all shapes automatically
    return model.expectShape(this.input.get(), StructureShape::class.java)
}

fun OperationShape.outputShape(model: Model): StructureShape {
    // The Rust Smithy generator adds an output to all shapes automatically
    return model.expectShape(this.output.get(), StructureShape::class.java)
}
Loading