Unverified Commit 7299cdd3 authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Improve broken protocol test generation (#3726)

We currently "hotfix" a broken protocol test in-memory, but there's no
mechanism that alerts us when the broken protocol test has been fixed
upstream when updating our Smithy version. This commit introduces such a
mechanism by generating both the original and the fixed test, with a
`#[should_panic]` attribute on the former, so that the test fails when
all its assertions succeed.

With this change, in general this approach of fixing tests in-memory
should now be used over adding the broken test to `expectFail` and
adding the fixed test to a `<protocol>-extras.smithy` Smithy model,
which is substantially more effort.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 24a011b9
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -23,12 +23,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
@@ -70,9 +70,9 @@ class ClientProtocolTestGenerator(
        private val ExpectFail =
            setOf<FailingTest>(
                // Failing because we don't serialize default values if they match the default.
                FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request),
                FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request),
                FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request),
                FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse"),
                FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
                FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"),
            )
    }

@@ -84,6 +84,8 @@ class ClientProtocolTestGenerator(
        get() = emptySet()
    override val disabledTests: Set<String>
        get() = emptySet()
    override val brokenTests: Set<BrokenTest>
        get() = emptySet()

    override val logger: Logger = Logger.getLogger(javaClass.name)

+203 −21
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
@@ -51,9 +52,17 @@ abstract class ProtocolTestGenerator {
    /**
     * We expect these tests to fail due to shortcomings in our implementation.
     * They will _fail_ if they pass, so we will discover and remove them if we fix them by accident.
     **/
     */
    abstract val expectFail: Set<FailingTest>

    /**
     * We expect these tests to fail because their definitions are broken.
     * We map from a failing test to a "hotfix" function that can mutate the test in-memory and return a fixed version of it.
     * The tests will _fail_ if they pass, so we will discover and remove the hotfix if we're updating to a newer
     * version of Smithy where the test was fixed upstream.
     */
    abstract val brokenTests: Set<BrokenTest>

    /** Only generate these tests; useful to temporarily set and shorten development cycles */
    abstract val runOnly: Set<String>

@@ -63,18 +72,23 @@ abstract class ProtocolTestGenerator {
     */
    abstract val disabledTests: Set<String>

    private val serviceShapeId: ShapeId
        get() = codegenContext.serviceShape.id

    /** The Rust module in which we should generate the protocol tests for [operationShape]. */
    private fun protocolTestsModule(): RustModule.LeafModule {
        val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name
        val testModuleName = "${operationName.toSnakeCase()}_test"
        val additionalAttributes =
            listOf(Attribute(allow("unreachable_code", "unused_variables")))
        val additionalAttributes = listOf(Attribute(allow("unreachable_code", "unused_variables")))
        return RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)
    }

    /** The entry point to render the protocol tests, invoked by the code generators. */
    fun render(writer: RustWriter) {
        val allTests = allMatchingTestCases().fixBroken()
        val allTests =
            allMatchingTestCases().flatMap {
                fixBrokenTestCase(it)
            }
        if (allTests.isEmpty()) {
            return
        }
@@ -84,15 +98,65 @@ abstract class ProtocolTestGenerator {
        }
    }

    /** Implementors should describe how to render the test cases. **/
    abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

    /**
     * This function applies a "fix function" to each broken test before we synthesize it.
     * Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong.
     * We try to contribute fixes upstream to pare down this function to the identity function.
     * This function applies a "hotfix function" to a broken test case before we synthesize it.
     * Broken tests are those whose definitions in the `smithy-lang/smithy` repository are wrong.
     * We try to contribute fixes upstream to pare down the list of broken tests.
     * If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]`
     * attribute, so get alerted if the test now passes, and the fixed version, which should pass.
     */
    open fun List<TestCase>.fixBroken(): List<TestCase> = this
    private fun fixBrokenTestCase(it: TestCase): List<TestCase> =
        if (!it.isBroken()) {
            listOf(it)
        } else {
            assert(it.expectFail())

            val brokenTest = it.findInBroken()!!
            var fixed = brokenTest.fixIt(it)

            val intro = "The hotfix function for broken test case ${it.kind} ${it.id}"
            val moreInfo =
                """This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}].
                |We are tracking things here: [${brokenTest.trackedIn.joinToString()}].
                """.trimMargin()

            // Something must change...
            if (it == fixed) {
                PANIC(
                    """$intro did not make any modifications. It is likely that the test case was 
                    |fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix 
                    |function, as the test is no longer broken.
                    |$moreInfo
                    """.trimMargin(),
                )
            }

            // ... but the hotfix function is not allowed to change the test case kind...
            if (it.kind != fixed.kind) {
                PANIC(
                    """$intro changed the test case kind. This is not allowed.
                    |$moreInfo
                    """.trimMargin(),
                )
            }

            // ... nor its id.
            if (it.id != fixed.id) {
                PANIC(
                    """$intro changed the test case id. This is not allowed.
                    |$moreInfo
                    """.trimMargin(),
                )
            }

            // The latter is because we're going to generate the fixed version with an identifiable suffix.
            fixed = fixed.suffixIdWith("_hotfixed")

            listOf(it, fixed)
        }

    /** Implementors should describe how to render the test cases. **/
    abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

    /** Filter out test cases that are disabled or don't match the service protocol. */
    private fun List<TestCase>.filterMatching(): List<TestCase> =
@@ -103,11 +167,25 @@ abstract class ProtocolTestGenerator {
            this.filter { testCase -> runOnly.contains(testCase.id) }
        }

    /** Do we expect this [testCase] to fail? */
    private fun expectFail(testCase: TestCase): Boolean =
        expectFail.find {
            it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString()
        } != null
    private fun TestCase.toFailingTest(): FailingTest =
        when (this) {
            is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id)
            is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id)
            is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id)
        }

    /** Do we expect this test case to fail? */
    private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest())

    /** Is this test case broken? */
    private fun TestCase.isBroken(): Boolean = this.findInBroken() != null

    private fun TestCase.findInBroken(): BrokenTest? =
        brokenTests.find { brokenTest ->
            (this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) ||
                (this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) ||
                (this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id)
        }

    fun requestTestCases(): List<TestCase> {
        val requestTests =
@@ -160,6 +238,7 @@ abstract class ProtocolTestGenerator {
        block: Writable,
    ) {
        if (testCase.documentation != null) {
            testModuleWriter.rust("")
            testModuleWriter.docs(testCase.documentation!!, templating = false)
        }
        testModuleWriter.docs("Test ID: ${testCase.id}")
@@ -171,7 +250,7 @@ abstract class ProtocolTestGenerator {
        Attribute.TokioTest.render(testModuleWriter)
        Attribute.TracedTest.render(testModuleWriter)

        if (expectFail(testCase)) {
        if (testCase.expectFail()) {
            shouldPanic().render(testModuleWriter)
        }
        val fnNameSuffix =
@@ -281,6 +360,51 @@ abstract class ProtocolTestGenerator {
    }
}

sealed class BrokenTest(
    open val serviceShapeId: String,
    open val id: String,
    /** A non-exhaustive set of Smithy versions where the test was found to be broken. */
    open val inAtLeast: Set<String>,
    /**
     * GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken,
     * or a PR where we fixed it.
     **/
    open val trackedIn: Set<String>,
) {
    data class RequestTest(
        override val serviceShapeId: String,
        override val id: String,
        override val inAtLeast: Set<String>,
        override val trackedIn: Set<String>,
        val howToFixItFn: (TestCase.RequestTest) -> TestCase.RequestTest,
    ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

    data class ResponseTest(
        override val serviceShapeId: String,
        override val id: String,
        override val inAtLeast: Set<String>,
        override val trackedIn: Set<String>,
        val howToFixItFn: (TestCase.ResponseTest) -> TestCase.ResponseTest,
    ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

    data class MalformedRequestTest(
        override val serviceShapeId: String,
        override val id: String,
        override val inAtLeast: Set<String>,
        override val trackedIn: Set<String>,
        val howToFixItFn: (TestCase.MalformedRequestTest) -> TestCase.MalformedRequestTest,
    ) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

    fun fixIt(testToFix: TestCase): TestCase {
        check(testToFix.id == this.id)
        return when (this) {
            is MalformedRequestTest -> howToFixItFn(testToFix as TestCase.MalformedRequestTest)
            is RequestTest -> howToFixItFn(testToFix as TestCase.RequestTest)
            is ResponseTest -> howToFixItFn(testToFix as TestCase.ResponseTest)
        }
    }
}

/**
 * Service shape IDs in common protocol test suites defined upstream.
 */
@@ -291,7 +415,16 @@ object ServiceShapeId {
    const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation"
}

data class FailingTest(val service: String, val id: String, val kind: TestCaseKind)
sealed class FailingTest(open val serviceShapeId: String, open val id: String) {
    data class RequestTest(override val serviceShapeId: String, override val id: String) :
        FailingTest(serviceShapeId, id)

    data class ResponseTest(override val serviceShapeId: String, override val id: String) :
        FailingTest(serviceShapeId, id)

    data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) :
        FailingTest(serviceShapeId, id)
}

sealed class TestCaseKind {
    data object Request : TestCaseKind()
@@ -302,11 +435,60 @@ sealed class TestCaseKind {
}

sealed class TestCase {
    data class RequestTest(val testCase: HttpRequestTestCase) : TestCase()
    /*
     * The properties of these data classes don't implement `equals()` usefully in Smithy, so we delegate to `equals()`
     * of their `Node` representations.
     */

    data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() {
        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other !is RequestTest) return false
            return testCase.toNode().equals(other.testCase.toNode())
        }

        override fun hashCode(): Int = testCase.hashCode()
    }

    data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() {
        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other !is ResponseTest) return false
            return testCase.toNode().equals(other.testCase.toNode())
        }

        override fun hashCode(): Int = testCase.hashCode()
    }

    data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() {
        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other !is MalformedRequestTest) return false
            return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation &&
                this.testCase.request.toNode()
                    .equals(other.testCase.request.toNode()) &&
                this.testCase.response.toNode()
                    .equals(other.testCase.response.toNode())
        }

        override fun hashCode(): Int = testCase.hashCode()
    }

    fun suffixIdWith(suffix: String): TestCase =
        when (this) {
            is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix))
            is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix))
            is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape)
        }

    private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase =
        this.toBuilder().id(this.id + suffix).build()

    data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase()
    private fun HttpResponseTestCase.suffixIdWith(suffix: String): HttpResponseTestCase =
        this.toBuilder().id(this.id + suffix).build()

    data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase()
    private fun HttpMalformedRequestTestCase.suffixIdWith(suffix: String): HttpMalformedRequestTestCase =
        this.toBuilder().id(this.id + suffix).build()

    /*
     * `HttpRequestTestCase` and `HttpResponseTestCase` both implement `HttpMessageTestCase`, but
+93 −107

File changed.

Preview size limit exceeded, changes collapsed.