Unverified Commit 9e05d1f0 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Escape enum variants to avoid collision with generated Unknown variant (#388)

* Escape enum variants to avoid collision with generated Unknown variant

* Output enum documentation with unknown variant renaming note

* CR feedback
parent 3eb53818
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -92,3 +92,10 @@ structure CollidingException {
// Fixing this is more refactoring than I want to get into right now
// @error("client")
// structure ErrCollisionsException { }

// The "Unknown" value on this enum collides with the code generated "Unknown" variant used for backwards compatibility
@enum([
    { name: "Known", value: "Known" },
    { name: "Unknown", value: "Unknown" },
])
string UnknownVariantCollidingEnum
+1 −1
Original line number Diff line number Diff line
@@ -117,7 +117,7 @@ class CodegenVisitor(context: PluginContext, private val codegenDecorator: RustC
    override fun stringShape(shape: StringShape) {
        shape.getTrait<EnumTrait>()?.also { enum ->
            rustCrate.useShapeWriter(shape) { writer ->
                EnumGenerator(symbolProvider, writer, shape, enum).render()
                EnumGenerator(model, symbolProvider, writer, shape, enum).render()
            }
        }
    }
+81 −21
Original line number Diff line number Diff line
@@ -6,10 +6,14 @@
package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.docs
import software.amazon.smithy.rust.codegen.rustlang.documentShape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
@@ -17,21 +21,78 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.util.doubleQuote
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.toPascalCase
import java.lang.IllegalStateException

/** Model that wraps [EnumDefinition] to calculate and cache values required to generate the Rust enum source. */
internal class EnumMemberModel(private val definition: EnumDefinition) {
    // Because enum variants always start with an upper case letter, they will never
    // conflict with reserved words (which are always lower case), therefore, we never need
    // to fall back to raw identifiers
    private val unescapedName: String? = definition.name.orNull()?.toPascalCase()

    val collidesWithUnknown: Boolean = unescapedName == EnumGenerator.UnknownVariant

    /** Enum name with correct case format and collision resolution */
    fun derivedName(): String = when (collidesWithUnknown) {
        // If there is a variant named "Unknown", then rename it to "UnknownValue" so that it
        // doesn't conflict with the code generator's "Unknown" variant that exists for backwards compatibility.
        true -> "UnknownValue"
        else -> checkNotNull(unescapedName) { "Enum variants must be named to derive a name. This is a bug." }
    }

    val value: String get() = definition.value

    private fun renderDocumentation(writer: RustWriter) {
        writer.docWithNote(
            definition.documentation.orNull(),
            when (collidesWithUnknown) {
                true ->
                    "`::${EnumGenerator.UnknownVariant}` has been renamed to `::${EnumGenerator.EscapedUnknownVariant}`. " +
                        "`::${EnumGenerator.UnknownVariant}` refers to additional values that may have been added since " +
                        "this enum was generated."
                else -> null
            }
        )
    }

    fun render(writer: RustWriter) {
        renderDocumentation(writer)
        writer.write("${derivedName()},")
    }
}

private fun RustWriter.docWithNote(doc: String?, note: String?) {
    doc?.also { docs(it) }
    note?.also {
        // Add a blank line between the docs and the note to visually differentiate
        doc?.also { write("///") }
        docs("**NOTE:** $it")
    }
}

class EnumGenerator(
    private val model: Model,
    symbolProvider: SymbolProvider,
    private val writer: RustWriter,
    shape: StringShape,
    private val shape: StringShape,
    private val enumTrait: EnumTrait
) {
    private val sortedMembers: List<EnumDefinition> = enumTrait.values.sortedBy { it.value }
    private val symbol = symbolProvider.toSymbol(shape)
    private val enumName = symbol.name
    private val meta = symbol.expectRustMetadata()
    private val sortedMembers: List<EnumMemberModel> = enumTrait.values.sortedBy { it.value }.map(::EnumMemberModel)

    companion object {
        /**
         * For enums with named members, variants with names that collide with the generated unknown enum
         * member get renamed to this [EscapedUnknownVariant] value.
         */
        const val EscapedUnknownVariant = "UnknownValue"
        /** Name of the generated unknown enum member name for enums with named members. */
        const val UnknownVariant = "Unknown"
        /** Name of the function on the enum impl to get a vec of value names */
        const val Values = "values"
    }

@@ -57,6 +118,7 @@ class EnumGenerator(
    }

    private fun renderUnamedEnum() {
        writer.documentShape(shape, model)
        meta.render(writer)
        writer.write("struct $enumName(String);")
        writer.rustBlock("impl $enumName") {
@@ -79,23 +141,22 @@ class EnumGenerator(
        }
    }

    private fun EnumDefinition.derivedName(): String {
        // Because enum variants always start with an upper case letter, they will never
        // conflict with reserved words (which are always lower case), therefore, we never need
        // to fall back to raw identifiers
        return name.orElse(null)?.toPascalCase()
            ?: throw IllegalStateException("Enum variants must be named to derive a name. This is a bug.")
    private fun renderEnum() {
        writer.docWithNote(
            shape.getTrait<DocumentationTrait>()?.value,
            when (sortedMembers.any { it.collidesWithUnknown }) {
                true ->
                    "`$enumName::$UnknownVariant` has been renamed to `::$EscapedUnknownVariant`. " +
                        "`$enumName::$UnknownVariant` refers to additional values that may have been added since " +
                        "this enum was generated."
                else -> null
            }
        )

    private fun renderEnum() {
        meta.render(writer)
        writer.rustBlock("enum $enumName") {
            sortedMembers.forEach { member ->
                member.documentation.map { setNewlinePrefix("/// ").write(it).setNewlinePrefix("") }
                // use the name, or escape the value
                write("${member.derivedName()},")
            }
            write("Unknown(String)")
            sortedMembers.forEach { member -> member.render(writer) }
            write("$UnknownVariant(String)")
        }
    }

@@ -106,7 +167,7 @@ class EnumGenerator(
                    sortedMembers.forEach { member ->
                        write("""$enumName::${member.derivedName()} => "${member.value}",""")
                    }
                    write("$enumName::Unknown(s) => s.as_ref()")
                    write("$enumName::$UnknownVariant(s) => s.as_ref()")
                }
            }
        }
@@ -142,7 +203,7 @@ class EnumGenerator(
                    sortedMembers.forEach { member ->
                        write(""""${member.value}" => $enumName::${member.derivedName()},""")
                    }
                    write("other => $enumName::Unknown(other.to_owned())")
                    write("other => $enumName::$UnknownVariant(other.to_owned())")
                }
            }
        }
@@ -156,7 +217,6 @@ class EnumGenerator(
                    Ok($enumName::from(s))
                }
            }

            """
        )
    }
+247 −107
Original line number Diff line number Diff line
@@ -5,21 +5,85 @@

package software.amazon.smithy.rust.codegen.generators

import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumMemberModel
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.codegen.util.orNull

class EnumGeneratorTest {
    @Nested
    inner class EnumMemberModelTests {
        private val testModel = """
            namespace test
            @enum([
                { value: "some-value-1",
                  name: "some_name_1",
                  documentation: "Some documentation." },
                { value: "some-value-2",
                  name: "someName2",
                  documentation: "More documentation" },
                { value: "unknown",
                  name: "unknown",
                  documentation: "It has some docs" }
            ])
            string EnumWithUnknown
        """.asSmithyModel()

        private val enumTrait = testModel.lookup<StringShape>("test#EnumWithUnknown").expectTrait<EnumTrait>()

        private fun model(name: String): EnumMemberModel =
            EnumMemberModel(enumTrait.values.first { it.name.orNull() == name })

        @Test
        fun `it converts enum names to PascalCase and renames any named Unknown to UnknownValue`() {
            model("some_name_1").derivedName() shouldBe "SomeName1"
            model("someName2").also { someName2 ->
                someName2.derivedName() shouldBe "SomeName2"
                someName2.collidesWithUnknown shouldBe false
            }
            model("unknown").also { unknown ->
                unknown.derivedName() shouldBe "UnknownValue"
                unknown.collidesWithUnknown shouldBe true
            }
        }

        @Test
        fun `it should render documentation`() {
            val rendered = RustWriter.forModule("model").also { model("some_name_1").render(it) }.toString()
            rendered shouldContain
                """
                /// Some documentation.
                SomeName1,
                """.trimIndent()
        }

        @Test
        fun `it adds a documentation note when renaming an enum named Unknown`() {
            val rendered = RustWriter.forModule("model").also { model("unknown").render(it) }.toString()
            rendered shouldContain
                """
                /// It has some docs
                ///
                /// **NOTE:** `::Unknown` has been renamed to `::UnknownValue`. `::Unknown` refers to additional values that may have been added since this enum was generated.
                UnknownValue,
                """.trimIndent()
        }
    }

    @Nested
    inner class EnumGeneratorTests {
        @Test
        fun `it generates named enums`() {
            val model = """
@@ -43,7 +107,7 @@ class EnumGeneratorTest {
            val provider: SymbolProvider = testSymbolProvider(model)
            val writer = RustWriter.forModule("model")
            val shape = model.lookup<StringShape>("test#InstanceType")
        val generator = EnumGenerator(provider, writer, shape, shape.expectTrait<EnumTrait>())
            val generator = EnumGenerator(model, provider, writer, shape, shape.expectTrait<EnumTrait>())
            generator.render()
            writer.compileAndTest(
                """
@@ -77,7 +141,7 @@ class EnumGeneratorTest {
            val shape: StringShape = model.lookup("test#FooEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val writer = RustWriter.forModule("model")
        val generator = EnumGenerator(testSymbolProvider(model), writer, shape, trait)
            val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait)
            generator.render()
            writer.compileAndTest(
                """
@@ -105,7 +169,7 @@ class EnumGeneratorTest {
            val shape: StringShape = model.lookup("test#FooEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val writer = RustWriter.forModule("model")
        val generator = EnumGenerator(testSymbolProvider(model), writer, shape, trait)
            val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait)
            generator.render()
            writer.compileAndTest(
                """
@@ -140,11 +204,11 @@ class EnumGeneratorTest {
            ])
            string FooEnum
            """.asSmithyModel()
        val shape = model.expectShape(ShapeId.from("test#FooEnum"), StringShape::class.java)
            val shape: StringShape = model.lookup("test#FooEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val provider: SymbolProvider = testSymbolProvider(model)
            val writer = RustWriter.forModule("model")
        val generator = EnumGenerator(provider, writer, shape, trait)
            val generator = EnumGenerator(model, provider, writer, shape, trait)
            generator.render()
            writer.compileAndTest(
                """
@@ -153,4 +217,80 @@ class EnumGeneratorTest {
                """
            )
        }

        @Test
        fun `it escapes the Unknown variant if the enum has an unknown value in the model`() {
            val model = """
                namespace test
                @enum([
                    { name: "Known", value: "Known" },
                    { name: "Unknown", value: "Unknown" },
                ])
                string SomeEnum
            """.asSmithyModel()

            val shape: StringShape = model.lookup("test#SomeEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val provider = testSymbolProvider(model)
            val writer = RustWriter.forModule("model")
            EnumGenerator(model, provider, writer, shape, trait).render()

            writer.compileAndTest(
                """
                assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue);
                assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into()));
                """
            )
        }

        @Test
        fun `it should generate documentation for enums`() {
            val model = """
                namespace test
                
                /// Some top-level documentation.
                @enum([
                    { name: "Known", value: "Known" },
                    { name: "Unknown", value: "Unknown" },
                ])
                string SomeEnum
            """.asSmithyModel()

            val shape: StringShape = model.lookup("test#SomeEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val provider = testSymbolProvider(model)
            val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString()

            rendered shouldContain
                """
                    /// Some top-level documentation.
                    ///
                    /// **NOTE:** `SomeEnum::Unknown` has been renamed to `::UnknownValue`. `SomeEnum::Unknown` refers to additional values that may have been added since this enum was generated.
                """.trimIndent()
        }

        @Test
        fun `it should generate documentation for unnamed enums`() {
            val model = """
                namespace test
                
                /// Some top-level documentation.
                @enum([
                    { value: "One" },
                    { value: "Two" },
                ])
                string SomeEnum
            """.asSmithyModel()

            val shape: StringShape = model.lookup("test#SomeEnum")
            val trait = shape.expectTrait<EnumTrait>()
            val provider = testSymbolProvider(model)
            val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString()

            rendered shouldContain
                """
                    /// Some top-level documentation.
                """.trimIndent()
        }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -175,7 +175,7 @@ internal class XmlBindingTraitParserGeneratorTest {
            model.lookup<StructureShape>("test#Top").renderWithModelBuilder(model, symbolProvider, it)
            UnionGenerator(model, symbolProvider, it, model.lookup("test#Choice")).render()
            val enum = model.lookup<StringShape>("test#FooEnum")
            EnumGenerator(symbolProvider, it, enum, enum.expectTrait()).render()
            EnumGenerator(model, symbolProvider, it, enum, enum.expectTrait()).render()
        }

        project.withModule(RustModule.default("output", public = true)) {