Unverified Commit 97b47adb authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Add non_exhaustive to structs, unions, and enums (#59)

* Add non_exhaustive to structs, unions, and enums

During a refactoring, `non_exhaustive` was accidentally removed from generated structurs. This commit re-adds it and validates its present with tests.

* Remove print

* Delete empty test
parent 7954752a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -125,7 +125,7 @@ data class RustMetadata(
    val additionalAttributes: List<Attribute> = listOf(),
    val public: Boolean
) {
    fun withDerive(newDerive: RuntimeType): RustMetadata =
    fun withDerives(vararg newDerive: RuntimeType): RustMetadata =
        this.copy(derives = derives.copy(derives = derives.derives + newDerive))

    fun attributes(): List<Attribute> = additionalAttributes + derives
+17 −14
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.lang.Attribute.Companion.NonExhaustive
import software.amazon.smithy.rust.codegen.lang.Derives
import software.amazon.smithy.rust.codegen.lang.RustMetadata

@@ -55,31 +56,32 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr
}

class BaseSymbolMetadataProvider(base: RustSymbolProvider) : SymbolMetadataProvider(base) {
    private val containerDefault = RustMetadata(
        Derives(defaultDerives.toSet()),
        additionalAttributes = listOf(NonExhaustive),
        public = true
    )

    override fun memberMeta(memberShape: MemberShape): RustMetadata {
        return RustMetadata(public = true)
    }

    override fun structureMeta(structureShape: StructureShape): RustMetadata {
        return RustMetadata(Derives(defaultDerives.toSet()), public = true)
        return containerDefault
    }

    override fun unionMeta(unionShape: UnionShape): RustMetadata {
        return RustMetadata(Derives(defaultDerives.toSet()), public = true)
        return containerDefault
    }

    override fun enumMeta(stringShape: StringShape): RustMetadata {
        return RustMetadata(
            Derives(
                defaultDerives.toSet() +
                    // enums must be hashable because string sets are hashable
                    RuntimeType.Std("hash::Hash") +
                    // enums can be eq because they can only contain strings
                    RuntimeType.Std("cmp::Eq") +
        return containerDefault.withDerives(
            RuntimeType.Std("hash::Hash")
        ).withDerives( // enums can be eq because they can only contain strings
            RuntimeType.Std("cmp::Eq"),
            // enums can be Ord because they can only contain strings
                    RuntimeType.Std("cmp::PartialOrd") +
            RuntimeType.Std("cmp::PartialOrd"),
            RuntimeType.Std("cmp::Ord")
            ),
            public = true
        )
    }

@@ -93,6 +95,7 @@ private const val MetaKey = "meta"
fun Symbol.Builder.meta(rustMetadata: RustMetadata?): Symbol.Builder {
    return this.putProperty(MetaKey, rustMetadata)
}

fun Symbol.expectRustMetadata(): RustMetadata = this.getProperty(MetaKey, RustMetadata::class.java).orElseThrow {
    CodegenException(
        "Expected $this to have metadata attached but it did not. "
+2 −2
Original line number Diff line number Diff line
@@ -82,8 +82,8 @@ class JsonSerializerSymbolProvider(
        val currentMeta = base.toSymbol(container).expectRustMetadata()
        val requiredSerde = serdeRequired(container)
        return currentMeta
            .letIf(requiredSerde.serialize) { it.withDerive(RuntimeType.Serialize) }
            .letIf(requiredSerde.deserialize) { it.withDerive(RuntimeType.Deserialize) }
            .letIf(requiredSerde.serialize) { it.withDerives(RuntimeType.Serialize) }
            .letIf(requiredSerde.deserialize) { it.withDerives(RuntimeType.Deserialize) }
    }

    private fun serdeRequired(shape: Shape): SerdeConfig {
+23 −30
Original line number Diff line number Diff line
@@ -5,13 +5,11 @@

package software.amazon.smithy.rust.codegen.generators

import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
@@ -19,42 +17,35 @@ import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.testutil.asSmithyModel
import software.amazon.smithy.rust.testutil.compileAndRun
import software.amazon.smithy.rust.testutil.compileAndTest
import software.amazon.smithy.rust.testutil.shouldCompile
import software.amazon.smithy.rust.testutil.shouldParseAsRust
import software.amazon.smithy.rust.testutil.testSymbolProvider

class EnumGeneratorTest {
    @Test
    fun `it generates named enums`() {
        val trait = EnumTrait.builder()
            .addEnum(EnumDefinition.builder().value("t2.nano").name("T2_NANO").build())
            .addEnum(
                EnumDefinition.builder().value("t2.micro").name("T2_MICRO").documentation(
                    "T2 instances are Burstable Performance\n" +
                        "Instances that provide a baseline level of CPU\n" +
                        "performance with the ability to burst above the\n" +
                        "baseline."
                ).build()
            )
            .build()

        val shape = StringShape.builder()
            .id("com.test#InstanceType")
            .addTrait(trait)
            .addTrait(DocumentationTrait("Documentation for this enum"))
            .build()

        val model = Model.assembler()
            .addShapes(shape)
            .assemble()
            .unwrap()
        val model = """
        namespace test
        @enum([
            {
                value: "t2.nano",
                name: "T2_NANO",
                documentation: "T2 instances are Burstable Performance Instances.",
                tags: ["ebsOnly"]
            },
            {
                value: "t2.micro",
                name: "T2_MICRO",
                documentation: "T2 instances are Burstable Performance Instances.",
                tags: ["ebsOnly"]
            },
        ])
        string InstanceType
        """.asSmithyModel()
        val provider: SymbolProvider = testSymbolProvider(model)
        val writer = RustWriter.forModule("model")
        val generator = EnumGenerator(provider, writer, shape, trait)
        val shape = model.lookup<StringShape>("test#InstanceType")
        val generator = EnumGenerator(provider, writer, shape, shape.expectTrait(EnumTrait::class.java))
        generator.render()
        val result = writer.toString()
        result.shouldParseAsRust()
        result.shouldCompile()
        result.compileAndRun(
            """
            let instance = InstanceType::T2Micro;
@@ -65,6 +56,8 @@ class EnumGeneratorTest {
            assert_eq!(InstanceType::from("other").as_str(), "other");
            """.trimIndent()
        )

        writer.toString() shouldContain "#[non_exhaustive]"
    }

    @Test
+2 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.generators

import io.kotest.matchers.string.shouldContainInOrder
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
@@ -65,6 +66,7 @@ class StructureGeneratorTest {
            s.map(|i|println!("{:?}, {:?}", i.ts, i.byte_value));
            """.trimIndent()
        )
        writer.toString().shouldContainInOrder("this documents the shape", "#[non_exhaustive]", "pub", "struct MyStruct")
    }

    @Test
Loading