Unverified Commit 086d9654 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Refactor client/server logic out of `EnumGenerator` (#2334)

* Refactor client/server enum generation logic into new `EnumType`
* Move client enum logic into `codegen-client`
parent bae93805
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientEnumGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader
@@ -29,7 +30,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.OperationErrorGenerator
@@ -43,7 +43,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha
import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isEventStream
import software.amazon.smithy.rust.codegen.core.util.letIf
@@ -206,9 +205,9 @@ class ClientCodegenVisitor(
     * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes.
     */
    override fun stringShape(shape: StringShape) {
        shape.getTrait<EnumTrait>()?.also { enum ->
        if (shape.hasTrait<EnumTrait>()) {
            rustCrate.useShapeWriter(shape) {
                EnumGenerator(model, symbolProvider, this, shape, enum).render()
                ClientEnumGenerator(codegenContext, shape).render(this)
            }
        }
    }
+170 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumType
import software.amazon.smithy.rust.codegen.core.util.dq

/** Infallible enums have an `Unknown` variant and can't fail to parse */
data class InfallibleEnumType(
    val unknownVariantModule: RustModule,
) : EnumType() {
    companion object {
        /** Name of the generated unknown enum member name for enums with named members. */
        const val UnknownVariant = "Unknown"

        /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */
        const val UnknownVariantValue = "UnknownVariantValue"
    }

    override fun implFromForStr(context: EnumGeneratorContext): Writable = writable {
        rustTemplate(
            """
            impl #{From}<&str> for ${context.enumName} {
                fn from(s: &str) -> Self {
                    match s {
                        #{matchArms}
                    }
                }
            }
            """,
            "From" to RuntimeType.From,
            "matchArms" to writable {
                context.sortedMembers.forEach { member ->
                    rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},")
                }
                rust(
                    "other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))",
                    unknownVariantValue(context),
                )
            },
        )
    }

    override fun implFromStr(context: EnumGeneratorContext): Writable = writable {
        rust(
            """
            impl std::str::FromStr for ${context.enumName} {
                type Err = std::convert::Infallible;

                fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
                    Ok(${context.enumName}::from(s))
                }
            }
            """,
        )
    }

    override fun additionalDocs(context: EnumGeneratorContext): Writable = writable {
        renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue)
    }

    override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable {
        docs("`$UnknownVariant` contains new variants that have been added since this code was generated.")
        rust("$UnknownVariant(#T)", unknownVariantValue(context))
    }

    override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable {
        rust("${context.enumName}::$UnknownVariant(value) => value.as_str()")
    }

    private fun unknownVariantValue(context: EnumGeneratorContext): RuntimeType {
        return RuntimeType.forInlineFun(UnknownVariantValue, RustModule.Types) {
            docs(
                """
                Opaque struct used as inner data for the `Unknown` variant defined in enums in
                the crate

                While this is not intended to be used directly, it is marked as `pub` because it is
                part of the enums that are public interface.
                """.trimIndent(),
            )
            context.enumMeta.render(this)
            rust("struct $UnknownVariantValue(pub(crate) String);")
            rustBlock("impl $UnknownVariantValue") {
                // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct.
                rustBlock("pub(crate) fn as_str(&self) -> &str") {
                    rust("&self.0")
                }
            }
        }
    }

    /**
     * Generate the rustdoc describing how to write a match expression against a generated enum in a
     * forward-compatible way.
     */
    private fun RustWriter.renderForwardCompatibilityNote(
        enumName: String, sortedMembers: List<EnumMemberModel>,
        unknownVariant: String, unknownVariantValue: String,
    ) {
        docs(
            """
            When writing a match expression against `$enumName`, it is important to ensure
            your code is forward-compatible. That is, if a match arm handles a case for a
            feature that is supported by the service but has not been represented as an enum
            variant in a current version of SDK, your code should continue to work when you
            upgrade SDK to a future version in which the enum does include a variant for that
            feature.
            """.trimIndent(),
        )
        docs("")
        docs("Here is an example of how you can make a match expression forward-compatible:")
        docs("")
        docs("```text")
        rust("/// ## let ${enumName.lowercase()} = unimplemented!();")
        rust("/// match ${enumName.lowercase()} {")
        sortedMembers.mapNotNull { it.name() }.forEach { member ->
            rust("///     $enumName::${member.name} => { /* ... */ },")
        }
        rust("""///     other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""")
        rust("///     _ => { /* ... */ },")
        rust("/// }")
        docs("```")
        docs(
            """
            The above code demonstrates that when `${enumName.lowercase()}` represents
            `NewFeature`, the execution path will lead to the second last match arm,
            even though the enum does not contain a variant `$enumName::NewFeature`
            in the current version of SDK. The reason is that the variable `other`,
            created by the `@` operator, is bound to
            `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))`
            and calling `as_str` on it yields `"NewFeature"`.
            This match expression is forward-compatible when executed with a newer
            version of SDK where the variant `$enumName::NewFeature` is defined.
            Specifically, when `${enumName.lowercase()}` represents `NewFeature`,
            the execution path will hit the second last match arm as before by virtue of
            calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`.
            """.trimIndent(),
        )
        docs("")
        docs(
            """
            Explicitly matching on the `$unknownVariant` variant should
            be avoided for two reasons:
            - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted.
            - It might inadvertently shadow other intended match arms.
            """.trimIndent(),
        )
    }
}

class ClientEnumGenerator(codegenContext: CodegenContext, shape: StringShape) :
    EnumGenerator(codegenContext.model, codegenContext.symbolProvider, shape, InfallibleEnumType(RustModule.Types))
+161 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.client.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup

class ClientEnumGeneratorTest {
    @Test
    fun `matching on enum should be forward-compatible`() {
        fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) {
            val shape = model.lookup<StringShape>(shapeId)
            val context = testCodegenContext(model)
            val project = TestWorkspace.testProject(context.symbolProvider)
            project.withModule(RustModule.Model) {
                ClientEnumGenerator(context, shape).render(this)
                unitTest(
                    "matching_on_enum_should_be_forward_compatible",
                    """
                    match $enumToMatchOn {
                        SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"),
                        SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"),
                        other @ _ if other.as_str() == "Variant3" => assert!(true),
                        _ => assert!(false, "expected `Variant3` but got `_`"),
                    }
                    """.trimIndent(),
                )
            }
            project.compileAndTest()
        }

        val modelV1 = """
            namespace test

            @enum([
                { name: "Variant1", value: "Variant1" },
                { name: "Variant2", value: "Variant2" },
            ])
            string SomeEnum
        """.asSmithyModel()
        val variant3AsUnknown = """SomeEnum::from("Variant3")"""
        expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown)

        val modelV2 = """
            namespace test

            @enum([
                { name: "Variant1", value: "Variant1" },
                { name: "Variant2", value: "Variant2" },
                { name: "Variant3", value: "Variant3" },
            ])
            string SomeEnum
        """.asSmithyModel()
        val variant3AsVariant3 = "SomeEnum::Variant3"
        expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3)
    }

    @Test
    fun `impl debug for non-sensitive enum should implement the derived debug trait`() {
        val model = """
            namespace test
            @enum([
                { name: "Foo", value: "Foo" },
                { name: "Bar", value: "Bar" },
            ])
            string SomeEnum
        """.asSmithyModel()

        val shape = model.lookup<StringShape>("test#SomeEnum")
        val context = testCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.withModule(RustModule.Model) {
            ClientEnumGenerator(context, shape).render(this)
            unitTest(
                "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait",
                """
                assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo");
                assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar");
                assert_eq!(
                    format!("{:?}", SomeEnum::from("Baz")),
                    "Unknown(UnknownVariantValue(\"Baz\"))"
                );
                """,
            )
        }
        project.compileAndTest()
    }

    @Test
    fun `it escapes the Unknown variant if the enum has an unknown value in the model`() {
        val model = """
            namespace test
            @enum([
                { name: "Known", value: "Known" },
                { name: "Unknown", value: "Unknown" },
                { name: "UnknownValue", value: "UnknownValue" },
            ])
            string SomeEnum
        """.asSmithyModel()

        val shape = model.lookup<StringShape>("test#SomeEnum")
        val context = testCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.withModule(RustModule.Model) {
            ClientEnumGenerator(context, shape).render(this)
            unitTest(
                "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model",
                """
                assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue);
                assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_);
                assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned())));
                """.trimIndent(),
            )
        }
        project.compileAndTest()
    }

    @Test
    fun `generated named enums can roundtrip between string and enum value on the unknown variant`() {
        val model = """
            namespace test
            @enum([
                { value: "t2.nano", name: "T2_NANO" },
                { value: "t2.micro", name: "T2_MICRO" },
            ])
            string InstanceType
        """.asSmithyModel()

        val shape = model.lookup<StringShape>("test#InstanceType")
        val context = testCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.withModule(RustModule.Model) {
            rust("##![allow(deprecated)]")
            ClientEnumGenerator(context, shape).render(this)
            unitTest(
                "generated_named_enums_roundtrip",
                """
                let instance = InstanceType::T2Micro;
                assert_eq!(instance.as_str(), "t2.micro");
                assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano);
                // round trip unknown variants:
                assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned())));
                assert_eq!(InstanceType::from("other").as_str(), "other");
                """,
            )
        }
        project.compileAndTest()
    }
}
+2 −4
Original line number Diff line number Diff line
@@ -12,13 +12,11 @@ import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.lookup

internal class ClientInstantiatorTest {
@@ -55,7 +53,7 @@ internal class ClientInstantiatorTest {

        val project = TestWorkspace.testProject()
        project.withModule(RustModule.Model) {
            EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render()
            ClientEnumGenerator(codegenContext, shape).render(this)
            unitTest("generate_named_enums") {
                withBlock("let result = ", ";") {
                    sut.render(this, shape, data)
@@ -74,7 +72,7 @@ internal class ClientInstantiatorTest {

        val project = TestWorkspace.testProject()
        project.withModule(RustModule.Model) {
            EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render()
            ClientEnumGenerator(codegenContext, shape).render(this)
            unitTest("generate_unnamed_enums") {
                withBlock("let result = ", ";") {
                    sut.render(this, shape, data)
+137 −204

File changed.

Preview size limit exceeded, changes collapsed.

Loading