Unverified Commit 5bbee2fc authored by Landon James's avatar Landon James Committed by GitHub
Browse files

Merge branch 'main' into update-all-lockfiles-1747173848

parents fc93cdae 32b6f563
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
---
applies_to: ["client", "server"]
authors: ["rcoh"]
references: ["smithy-rs#4132"]
breaking: false
new_feature: false
bug_fix: true
---

Smithy unions that contain members named "unknown" will now codegen correctly
+6 −2
Original line number Diff line number Diff line
@@ -110,7 +110,8 @@ open class UnionGenerator(
    private fun renderImplBlock(unionSymbol: Symbol) {
        writer.rustBlock("impl ${unionSymbol.name}") {
            sortedMembers.forEach { member ->
                val funcNamePart = member.memberName.toSnakeCase()
                // We need to get the symbol first because the member can be renamed
                val funcNamePart = symbolProvider.toSymbol(member).name.toSnakeCase()
                val variantName = symbolProvider.toMemberName(member)

                if (sortedMembers.size == 1) {
@@ -219,7 +220,10 @@ private fun RustWriter.renderAsVariant(
            targetSymbol,
        )
        rust("/// Returns `Err(&Self)` if it can't be converted.")
        rustBlockTemplate("pub fn as_$funcNamePart(&self) -> #{Result}<&${memberSymbol.rustType().render()}, &Self>", *preludeScope) {
        rustBlockTemplate(
            "pub fn as_$funcNamePart(&self) -> #{Result}<&${memberSymbol.rustType().render()}, &Self>",
            *preludeScope,
        ) {
            rustTemplate(
                "if let ${unionSymbol.name}::$variantName(val) = &self { #{Ok}(val) } else { #{Err}(self) }",
                *preludeScope,
+49 −3
Original line number Diff line number Diff line
@@ -7,9 +7,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators

import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
@@ -42,6 +44,28 @@ class UnionGeneratorTest {
        writer.toString() shouldContain "#[non_exhaustive]"
    }

    @Test
    fun `generate basic union with member names Unknown`() {
        val writer =
            generateUnion(
                """
                union MyUnion {
                    unknown: String
                }
                """,
            )

        writer.compileAndTest(
            """
            let var_a = MyUnion::UnknownValue("abc".to_string());
            let var_b = MyUnion::Unknown;
            assert_ne!(var_a, var_b);
            assert_eq!(var_a, var_a);
            """,
        )
        writer.toString() shouldContain "#[non_exhaustive]"
    }

    @Test
    fun `generate conversion helper methods`() {
        val writer =
@@ -232,9 +256,31 @@ class UnionGeneratorTest {
        unknownVariant: Boolean = true,
    ): RustWriter {
        val model = "namespace test\n$modelSmithy".asSmithyModel()
        val provider: SymbolProvider = testSymbolProvider(model)
        // Reserved words to test generation of renamed members
        val reservedWords =
            RustReservedWordConfig(
                structureMemberMap =
                    StructureGenerator.structureMemberNameMap,
                unionMemberMap =
                    mapOf(
                        // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
                        // that represent union variants that have been added since this SDK was generated.
                        UnionGenerator.UNKNOWN_VARIANT_NAME to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value",
                        "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value" to "${UnionGenerator.UNKNOWN_VARIANT_NAME}Value_",
                    ),
                enumMemberMap =
                    mapOf(),
            )
        val provider: RustSymbolProvider = testSymbolProvider(model)
        val reservedWordsProvider = RustReservedWordSymbolProvider(provider, reservedWords)
        val writer = RustWriter.forModule("model")
        UnionGenerator(model, provider, writer, model.lookup("test#$unionName"), renderUnknownVariant = unknownVariant).render()
        UnionGenerator(
            model,
            reservedWordsProvider,
            writer,
            model.lookup("test#$unionName"),
            renderUnknownVariant = unknownVariant,
        ).render()
        return writer
    }
}