Unverified Commit 507ebe8a authored by Zelda Hessler's avatar Zelda Hessler Committed by GitHub
Browse files

v2 Smoketest codegen (#3758)

This PR adds codegen for service-defined smoketests.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
parent 1223f61e
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -27,6 +27,8 @@ dependencies {
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion")
    implementation("software.amazon.smithy:smithy-aws-endpoints:$smithyVersion")
    implementation("software.amazon.smithy:smithy-smoke-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-aws-smoke-test-model:$smithyVersion")
}

java {
+1 −0
Original line number Diff line number Diff line
@@ -63,6 +63,7 @@ val DECORATORS: List<ClientCodegenDecorator> =
            TokenProvidersDecorator(),
            ServiceEnvConfigDecorator(),
            HttpRequestCompressionDecorator(),
            SmokeTestsDecorator(),
        ),
        // S3 needs `AwsErrorCodeClassifier` to handle an `InternalError` as a transient error. We need to customize
        // that behavior for S3 in a way that does not conflict with the globally applied `RetryClassifierDecorator`.
+219 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import software.amazon.smithy.aws.smoketests.model.AwsSmokeTestModel
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.cfg
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.containerDocs
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.smoketests.traits.Expectation
import software.amazon.smithy.smoketests.traits.SmokeTestCase
import software.amazon.smithy.smoketests.traits.SmokeTestsTrait
import java.util.Optional
import java.util.logging.Logger

class SmokeTestsDecorator : ClientCodegenDecorator {
    override val name: String = "SmokeTests"
    override val order: Byte = 0
    private val logger: Logger = Logger.getLogger(javaClass.name)

    private fun isSmokeTestSupported(smokeTestCase: SmokeTestCase): Boolean {
        AwsSmokeTestModel.getAwsVendorParams(smokeTestCase)?.orNull()?.let { vendorParams ->
            if (vendorParams.sigv4aRegionSet.isPresent) {
                logger.warning("skipping smoketest `${smokeTestCase.id}` with unsupported vendorParam `sigv4aRegionSet`")
                return false
            }
            // TODO(https://github.com/smithy-lang/smithy-rs/issues/3776) Once Account ID routing is supported,
            //     update the vendorParams setter and remove this check.
            if (vendorParams.useAccountIdRouting()) {
                logger.warning("skipping smoketest `${smokeTestCase.id}` with unsupported vendorParam `useAccountIdRouting`")
                return false
            }
        }
        AwsSmokeTestModel.getS3VendorParams(smokeTestCase)?.orNull()?.let { s3VendorParams ->
            if (s3VendorParams.useGlobalEndpoint()) {
                logger.warning("skipping smoketest `${smokeTestCase.id}` with unsupported vendorParam `useGlobalEndpoint`")
                return false
            }
        }

        return true
    }

    override fun extras(
        codegenContext: ClientCodegenContext,
        rustCrate: RustCrate,
    ) {
        // Get all operations with smoke tests
        val smokeTestedOperations =
            codegenContext.model.getOperationShapesWithTrait(SmokeTestsTrait::class.java).toList()
        val supportedTests =
            smokeTestedOperations.map { operationShape ->
                // filter out unsupported smoke tests, logging a warning for each one.
                val testCases =
                    operationShape.expectTrait<SmokeTestsTrait>().testCases.filter { smokeTestCase ->
                        isSmokeTestSupported(smokeTestCase)
                    }

                operationShape to testCases
            }
                // filter out operations with no supported smoke tests
                .filter { (_, testCases) -> testCases.isNotEmpty() }
        // Return if there are no supported smoke tests across all operations
        if (supportedTests.isEmpty()) return

        rustCrate.integrationTest("smoketests") {
            // Don't run the tests in this module unless `RUSTFLAGS="--cfg smoketests"` is passed.
            Attribute(cfg("smoketests")).render(this, AttributeKind.Inner)

            containerDocs(
                """
                The tests in this module run against live AWS services. As such,
                they are disabled by default. To enable them, run the tests with

                ```sh
                RUSTFLAGS="--cfg smoketests" cargo test.
                ```""",
            )

            val model = codegenContext.model
            val moduleUseName = codegenContext.moduleUseName()
            rust("use $moduleUseName::{ Client, config };")

            for ((operationShape, testCases) in supportedTests) {
                val operationName = operationShape.id.name.toSnakeCase()
                val operationInput = operationShape.inputShape(model)

                docs("Smoke tests for the `$operationName` operation")

                for (testCase in testCases) {
                    Attribute.TokioTest.render(this)
                    this.rustBlock("async fn test_${testCase.id.toSnakeCase()}()") {
                        val instantiator = SmokeTestsInstantiator(codegenContext)
                        instantiator.renderConf(this, testCase)
                        rust("let client = Client::from_conf(conf);")
                        instantiator.renderInput(this, operationShape, operationInput, testCase.params)
                        instantiator.renderExpectation(this, model, testCase.expectation)
                    }
                }
            }
        }
    }
}

class SmokeTestsBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
    override fun hasFallibleBuilder(shape: StructureShape): Boolean =
        BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)

    override fun setterName(memberShape: MemberShape): String = memberShape.setterName()

    override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true
}

class SmokeTestsInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator(
    codegenContext.symbolProvider,
    codegenContext.model,
    codegenContext.runtimeConfig,
    SmokeTestsBuilderKindBehavior(codegenContext),
) {
    fun renderConf(
        writer: RustWriter,
        testCase: SmokeTestCase,
    ) {
        writer.rust("let conf = config::Builder::new()")
        writer.indent()
        writer.rust(".behavior_version(config::BehaviorVersion::latest())")

        val vendorParams = AwsSmokeTestModel.getAwsVendorParams(testCase)
        vendorParams.orNull()?.let { params ->
            writer.rust(".region(config::Region::new(${params.region.dq()}))")
            writer.rust(".use_dual_stack(${params.useDualstack()})")
            writer.rust(".use_fips(${params.useFips()})")
            params.uri.orNull()?.let { writer.rust(".endpoint_url($it)") }
        }

        val s3VendorParams = AwsSmokeTestModel.getS3VendorParams(testCase)
        s3VendorParams.orNull()?.let { params ->
            writer.rust(".accelerate_(${params.useAccelerate()})")
            writer.rust(".force_path_style_(${params.forcePathStyle()})")
            writer.rust(".use_arn_region(${params.useArnRegion()})")
            writer.rust(".disable_multi_region_access_points(${params.useMultiRegionAccessPoints().not()})")
        }

        writer.rust(".build();")
        writer.dedent()
    }

    fun renderInput(
        writer: RustWriter,
        operationShape: OperationShape,
        inputShape: StructureShape,
        data: Optional<ObjectNode>,
        headers: Map<String, String> = mapOf(),
        ctx: Ctx = Ctx(),
    ) {
        val operationBuilderName =
            FluentClientGenerator.clientOperationFnName(operationShape, codegenContext.symbolProvider)

        writer.rust("let res = client.$operationBuilderName()")
        writer.indent()
        data.orNull()?.let {
            renderStructureMembers(writer, inputShape, it, headers, ctx)
        }
        writer.rust(".send().await;")
        writer.dedent()
    }

    fun renderExpectation(
        writer: RustWriter,
        model: Model,
        expectation: Expectation,
    ) {
        if (expectation.isSuccess) {
            writer.rust("""res.expect("request should succeed");""")
        } else if (expectation.isFailure) {
            val expectedErrShape = expectation.failure.orNull()?.errorId?.orNull()
            println(expectedErrShape)
            if (expectedErrShape != null) {
                val failureShape = model.expectShape(expectedErrShape)
                val errName = codegenContext.symbolProvider.toSymbol(failureShape).name.toSnakeCase()
                writer.rust(
                    """
                    let err = res.expect_err("request should fail");
                    let err = err.into_service_error();
                    assert!(err.is_$errName())
                    """,
                )
            } else {
                writer.rust("""res.expect_err("request should fail");""")
            }
        }
    }
}
+96 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel

class SmokeTestsDecoratorTest {
    companion object {
        // Can't use the dollar sign in a multiline string with doing it like this.
        private const val PREFIX = "\$version: \"2\""
        val model =
            """
            $PREFIX
            namespace test

            use aws.api#service
            use smithy.test#smokeTests
            use aws.auth#sigv4
            use aws.protocols#restJson1
            use smithy.rules#endpointRuleSet

            @service(sdkId: "dontcare")
            @restJson1
            @sigv4(name: "dontcare")
            @auth([sigv4])
            @endpointRuleSet({
                "version": "1.0",
                "rules": [{ "type": "endpoint", "conditions": [], "endpoint": { "url": "https://example.com" } }],
                "parameters": {
                    "Region": { "required": false, "type": "String", "builtIn": "AWS::Region" },
                }
            })
            service TestService {
                version: "2023-01-01",
                operations: [SomeOperation]
            }

            @smokeTests([
                {
                    id: "SomeOperationSuccess",
                    params: {}
                    vendorParams: {
                        region: "us-west-2"
                    }
                    expect: { success: {} }
                }
                {
                    id: "SomeOperationFailure",
                    params: {}
                    vendorParams: {
                        region: "us-west-2"
                    }
                    expect: { failure: {} }
                }
                {
                    id: "SomeOperationFailureExplicitShape",
                    params: {}
                    vendorParams: {
                        region: "us-west-2"
                    }
                    expect: {
                        failure: { errorId: FooException }
                    }
                }
            ])
            @http(uri: "/SomeOperation", method: "POST")
            @optionalAuth
            operation SomeOperation {
                input: SomeInput,
                output: SomeOutput,
                errors: [FooException]
            }

            @input
            structure SomeInput {}

            @output
            structure SomeOutput {}

            @error("server")
            structure FooException { }
            """.asSmithyModel()
    }

    @Test
    fun smokeTestSdkCodegen() {
        awsSdkIntegrationTest(model) { _, _ ->
            // It should compile. We can't run the tests
            // because they don't target a real service.
        }
    }
}