Unverified Commit e56f0dd6 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Refactor and DRY up protocol test generation (#3713)

Protocol test generation code was one of the earlier parts of
smithy-rs's codebase and it has accrued a fair amount of tech debt as we
have evolved the code generation primitives. Its code was also forked
when the server code generator was introduced, introducing a lot of
duplicated code that has deviated over time.

This commit refactors the code to modern standards and aims to reconcile
commonalities in `ProtocolTestGenerator`, so that both client and server
can reap centralized improvements over time.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent cba3edfb
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -13,8 +13,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDocs
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
@@ -28,6 +27,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.serviceNameOrDefault
import software.amazon.smithy.rustsdk.customize.s3.S3ExpressFluentClientCustomization
@@ -91,7 +91,7 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
        codegenContext: ClientCodegenContext,
        baseGenerator: ProtocolTestGenerator,
    ): ProtocolTestGenerator =
        DefaultProtocolTestGenerator(
        ClientProtocolTestGenerator(
            codegenContext,
            baseGenerator.protocolSupport,
            baseGenerator.operationShape,
+2 −2
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGen
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.DefaultProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolTestGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader
import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage
import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations
@@ -322,7 +322,7 @@ class ClientCodegenVisitor(
            // render protocol tests into `operation.rs` (note operationWriter vs. inputWriter)
            codegenDecorator.protocolTestGenerator(
                codegenContext,
                DefaultProtocolTestGenerator(
                ClientProtocolTestGenerator(
                    codegenContext,
                    protocolGeneratorFactory.support(),
                    operationShape,
+1 −1
Original line number Diff line number Diff line
@@ -16,10 +16,10 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGen
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator
import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import java.util.ServiceLoader
import java.util.logging.Logger
+1 −1
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
@@ -25,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomiza
import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator

/**
 * Delegating decorator that only applies when a condition is true
+34 −256
Original line number Diff line number Diff line
@@ -5,43 +5,37 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.protocoltests.traits.AppliesTo
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase
import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait
import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase
import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import java.util.logging.Logger
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType as RT

@@ -52,18 +46,10 @@ data class ClientCreationParams(
    val clientName: String,
)

interface ProtocolTestGenerator {
    val codegenContext: ClientCodegenContext
    val protocolSupport: ProtocolSupport
    val operationShape: OperationShape

    fun render(writer: RustWriter)
}

/**
 * Generate protocol tests for an operation
 * Generate client protocol tests for an [operationShape].
 */
class DefaultProtocolTestGenerator(
class ClientProtocolTestGenerator(
    override val codegenContext: ClientCodegenContext,
    override val protocolSupport: ProtocolSupport,
    override val operationShape: OperationShape,
@@ -79,121 +65,53 @@ class DefaultProtocolTestGenerator(
            "Client" to ClientRustModule.root.toType().resolve("Client"),
        )
    },
) : ProtocolTestGenerator {
) : ProtocolTestGenerator() {
    companion object {
        private val ExpectFail =
            setOf<FailingTest>(
                // Failing because we don't serialize default values if they match the default.
                FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request),
                FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request),
                FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request),
            )
    }

    override val appliesTo: AppliesTo
        get() = AppliesTo.CLIENT
    override val expectFail: Set<FailingTest>
        get() = ExpectFail
    override val runOnly: Set<String>
        get() = emptySet()
    override val disabledTests: Set<String>
        get() = emptySet()

    override val logger: Logger = Logger.getLogger(javaClass.name)

    private val rc = codegenContext.runtimeConfig
    private val logger = Logger.getLogger(javaClass.name)

    private val inputShape = operationShape.inputShape(codegenContext.model)
    private val outputShape = operationShape.outputShape(codegenContext.model)
    private val operationSymbol = codegenContext.symbolProvider.toSymbol(operationShape)
    private val operationIndex = OperationIndex.of(codegenContext.model)

    private val instantiator = ClientInstantiator(codegenContext)

    private val codegenScope =
        arrayOf(
            "SmithyHttp" to RT.smithyHttp(rc),
            "AssertEq" to RT.PrettyAssertions.resolve("assert_eq!"),
            "Uri" to RT.Http.resolve("Uri"),
        )

    sealed class TestCase {
        abstract val testCase: HttpMessageTestCase

        data class RequestTest(override val testCase: HttpRequestTestCase) : TestCase()

        data class ResponseTest(override val testCase: HttpResponseTestCase, val targetShape: StructureShape) :
            TestCase()
    }

    override fun render(writer: RustWriter) {
        val requestTests =
            operationShape.getTrait<HttpRequestTestsTrait>()
                ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.RequestTest(it) }
        val responseTests =
            operationShape.getTrait<HttpResponseTestsTrait>()
                ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty().map { TestCase.ResponseTest(it, outputShape) }
        val errorTests =
            operationIndex.getErrors(operationShape).flatMap { error ->
                val testCases =
                    error.getTrait<HttpResponseTestsTrait>()
                        ?.getTestCasesFor(AppliesTo.CLIENT).orEmpty()
                testCases.map { TestCase.ResponseTest(it, error) }
            }
        val allTests: List<TestCase> = (requestTests + responseTests + errorTests).filterMatching()
        if (allTests.isNotEmpty()) {
            val operationName = operationSymbol.name
            val testModuleName = "${operationName.toSnakeCase()}_request_test"
            val additionalAttributes =
                listOf(
                    Attribute(allow("unreachable_code", "unused_variables")),
                )
            writer.withInlineModule(
                RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes),
                null,
            ) {
                renderAllTestCases(allTests)
            }
        }
    }

    private fun RustWriter.renderAllTestCases(allTests: List<TestCase>) {
        allTests.forEach {
            renderTestCaseBlock(it.testCase, this) {
    override fun RustWriter.renderAllTestCases(allTests: List<TestCase>) {
        for (it in allTests) {
            renderTestCaseBlock(it, this) {
                when (it) {
                    is TestCase.RequestTest -> this.renderHttpRequestTestCase(it.testCase)
                    is TestCase.ResponseTest -> this.renderHttpResponseTestCase(it.testCase, it.targetShape)
                    is TestCase.MalformedRequestTest -> PANIC("Client protocol test generation does not support HTTP compliance test case type `$it`")
                }
            }
        }
    }

    /**
     * Filter out test cases that are disabled or don't match the service protocol
     */
    private fun List<TestCase>.filterMatching(): List<TestCase> {
        return if (RunOnly.isNullOrEmpty()) {
            this.filter { testCase ->
                testCase.testCase.protocol == codegenContext.protocol &&
                    !DisableTests.contains(testCase.testCase.id)
            }
        } else {
            this.filter { RunOnly.contains(it.testCase.id) }
        }
    }

    private fun renderTestCaseBlock(
        testCase: HttpMessageTestCase,
        testModuleWriter: RustWriter,
        block: Writable,
    ) {
        testModuleWriter.newlinePrefix = "/// "
        testCase.documentation.map {
            testModuleWriter.writeWithNoFormatting(it)
        }
        testModuleWriter.write("Test ID: ${testCase.id}")
        testModuleWriter.newlinePrefix = ""
        Attribute.TokioTest.render(testModuleWriter)
        val action =
            when (testCase) {
                is HttpResponseTestCase -> Action.Response
                is HttpRequestTestCase -> Action.Request
                else -> throw CodegenException("unknown test case type")
            }
        if (expectFail(testCase)) {
            testModuleWriter.writeWithNoFormatting("#[should_panic]")
        }
        val fnName =
            when (action) {
                is Action.Response -> "_response"
                is Action.Request -> "_request"
            }
        Attribute.AllowUnusedMut.render(testModuleWriter)
        testModuleWriter.rustBlock("async fn ${testCase.id.toSnakeCase()}$fnName()") {
            block(this)
        }
    }

    private fun RustWriter.renderHttpRequestTestCase(httpRequestTestCase: HttpRequestTestCase) {
        if (!protocolSupport.requestSerialization) {
            rust("/* test case disabled for this protocol (not yet supported) */")
@@ -276,18 +194,6 @@ class DefaultProtocolTestGenerator(
        }
    }

    private fun HttpMessageTestCase.action(): Action =
        when (this) {
            is HttpRequestTestCase -> Action.Request
            is HttpResponseTestCase -> Action.Response
            else -> throw CodegenException("Unknown test case type")
        }

    private fun expectFail(testCase: HttpMessageTestCase): Boolean =
        ExpectFail.find {
            it.id == testCase.id && it.action == testCase.action() && it.service == codegenContext.serviceShape.id.toString()
        } != null

    private fun RustWriter.renderHttpResponseTestCase(
        testCase: HttpResponseTestCase,
        expectedShape: StructureShape,
@@ -434,58 +340,6 @@ class DefaultProtocolTestGenerator(
        }
    }

    private fun checkRequiredHeaders(
        rustWriter: RustWriter,
        actualExpression: String,
        requireHeaders: List<String>,
    ) {
        basicCheck(
            requireHeaders,
            rustWriter,
            "required_headers",
            actualExpression,
            "require_headers",
        )
    }

    private fun checkForbidHeaders(
        rustWriter: RustWriter,
        actualExpression: String,
        forbidHeaders: List<String>,
    ) {
        basicCheck(
            forbidHeaders,
            rustWriter,
            "forbidden_headers",
            actualExpression,
            "forbid_headers",
        )
    }

    private fun checkHeaders(
        rustWriter: RustWriter,
        actualExpression: String,
        headers: Map<String, String>,
    ) {
        if (headers.isEmpty()) {
            return
        }
        val variableName = "expected_headers"
        rustWriter.withBlock("let $variableName = [", "];") {
            writeWithNoFormatting(
                headers.entries.joinToString(",") {
                    "(${it.key.dq()}, ${it.value.dq()})"
                },
            )
        }
        assertOk(rustWriter) {
            write(
                "#T($actualExpression, $variableName)",
                RT.protocolTest(rc, "validate_headers"),
            )
        }
    }

    private fun checkRequiredQueryParams(
        rustWriter: RustWriter,
        requiredParams: List<String>,
@@ -518,80 +372,4 @@ class DefaultProtocolTestGenerator(
        "&http_request",
        "validate_query_string",
    )

    private fun basicCheck(
        params: List<String>,
        rustWriter: RustWriter,
        expectedVariableName: String,
        actualExpression: String,
        checkFunction: String,
    ) {
        if (params.isEmpty()) {
            return
        }
        rustWriter.withBlock("let $expectedVariableName = ", ";") {
            strSlice(this, params)
        }
        assertOk(rustWriter) {
            write(
                "#T($actualExpression, $expectedVariableName)",
                RT.protocolTest(rc, checkFunction),
            )
        }
    }

    /**
     * wraps `inner` in a call to `aws_smithy_protocol_test::assert_ok`, a convenience wrapper
     * for pretty printing protocol test helper results
     */
    private fun assertOk(
        rustWriter: RustWriter,
        inner: Writable,
    ) {
        rustWriter.write("#T(", RT.protocolTest(rc, "assert_ok"))
        inner(rustWriter)
        rustWriter.write(");")
    }

    private fun strSlice(
        writer: RustWriter,
        args: List<String>,
    ) {
        writer.withBlock("&[", "]") {
            write(args.joinToString(",") { it.dq() })
        }
    }

    companion object {
        sealed class Action {
            object Request : Action()

            object Response : Action()
        }

        data class FailingTest(val service: String, val id: String, val action: Action)

        // These tests fail due to shortcomings in our implementation.
        // These could be configured via runtime configuration, but since this won't be long-lasting,
        // it makes sense to do the simplest thing for now.
        // The test will _fail_ if these pass, so we will discover & remove if we fix them by accident
        private val JsonRpc10 = "aws.protocoltests.json10#JsonRpc10"
        private val AwsJson11 = "aws.protocoltests.json#JsonProtocol"
        private val RestJson = "aws.protocoltests.restjson#RestJson"
        private val RestXml = "aws.protocoltests.restxml#RestXml"
        private val AwsQuery = "aws.protocoltests.query#AwsQuery"
        private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
        private val ExpectFail =
            setOf<FailingTest>(
                // Failing because we don't serialize default values if they match the default
                FailingTest(JsonRpc10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", Action.Request),
                FailingTest(JsonRpc10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", Action.Request),
                FailingTest(JsonRpc10, "AwsJson10ClientPopulatesDefaultValuesInInput", Action.Request),
            )
        private val RunOnly: Set<String>? = null

        // These tests are not even attempted to be generated, either because they will not compile
        // or because they are flaky
        private val DisableTests: Set<String> = setOf()
    }
}
Loading