Unverified Commit 6a8e4274 authored by Victor Barbu's avatar Victor Barbu Committed by GitHub
Browse files

Add as_xyz and is_xyz helper methods on Smithy Unions (#321)



* Add as_xyz and is_xyz helper methods on Smithy Unions

Issue #298

* Fix Unions generation and add unit test

Co-authored-by: default avatarVictor Barbu <vicbarbu@amazon.com>
parent 7c764055
Loading
Loading
Loading
Loading
+16 −0
Original line number Diff line number Diff line
@@ -11,9 +11,11 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.documentShape
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.util.toPascalCase
import software.amazon.smithy.rust.codegen.util.toSnakeCase

class UnionGenerator(
    val model: Model,
@@ -39,5 +41,19 @@ class UnionGenerator(
                write("${member.memberName.toPascalCase()}(#T),", symbolProvider.toSymbol(member))
            }
        }
        writer.rustBlock("impl ${symbol.name}") {
            sortedMembers.forEach { member ->
                val memberSymbol = symbolProvider.toSymbol(member)
                val funcNamePart = member.memberName.toSnakeCase()
                val variantName = member.memberName.toPascalCase()

                writer.rustBlock("pub fn as_$funcNamePart(&self) -> Option<&#T>", memberSymbol) {
                    rust("if let ${symbol.name}::$variantName(val) = &self { Some(&val) } else { None }")
                }
                writer.rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
                    rust("self.as_$funcNamePart().is_some()")
                }
            }
        }
    }
}
+30 −0
Original line number Diff line number Diff line
@@ -40,4 +40,34 @@ class UnionGeneratorTest {
        )
        writer.toString() shouldContain "#[non_exhaustive]"
    }

    @Test
    fun `generate conversion helper methods`() {
        val model = """
        namespace test
        union MyUnion {
            stringValue: String,
            intValue: PrimitiveInteger
        }
        """.asSmithyModel()
        val provider: SymbolProvider = testSymbolProvider(model)
        val writer = RustWriter.forModule("model")
        val generator = UnionGenerator(model, provider, writer, model.lookup("test#MyUnion"))
        generator.render()

        writer.compileAndTest(
            """
            let foo = MyUnion::StringValue("foo".to_string());
            let bar = MyUnion::IntValue(10);
            assert_eq!(foo.is_string_value(), true);
            assert_eq!(foo.is_int_value(), false);
            assert_eq!(foo.as_string_value(), Some(&"foo".to_string()));
            assert_eq!(foo.as_int_value(), None);
            assert_eq!(bar.is_string_value(), false);
            assert_eq!(bar.is_int_value(), true);
            assert_eq!(bar.as_string_value(), None);
            assert_eq!(bar.as_int_value(), Some(&10));
        """
        )
    }
}