Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/UnionGenerator.kt +16 −0 Original line number Diff line number Diff line Loading @@ -11,9 +11,11 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.documentShape import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.util.toPascalCase import software.amazon.smithy.rust.codegen.util.toSnakeCase class UnionGenerator( val model: Model, Loading @@ -39,5 +41,19 @@ class UnionGenerator( write("${member.memberName.toPascalCase()}(#T),", symbolProvider.toSymbol(member)) } } writer.rustBlock("impl ${symbol.name}") { sortedMembers.forEach { member -> val memberSymbol = symbolProvider.toSymbol(member) val funcNamePart = member.memberName.toSnakeCase() val variantName = member.memberName.toPascalCase() writer.rustBlock("pub fn as_$funcNamePart(&self) -> Option<&#T>", memberSymbol) { rust("if let ${symbol.name}::$variantName(val) = &self { Some(&val) } else { None }") } writer.rustBlock("pub fn is_$funcNamePart(&self) -> bool") { rust("self.as_$funcNamePart().is_some()") } } } } } codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt +30 −0 Original line number Diff line number Diff line Loading @@ -40,4 +40,34 @@ class UnionGeneratorTest { ) writer.toString() shouldContain "#[non_exhaustive]" } @Test fun `generate conversion helper methods`() { val model = """ namespace test union MyUnion { stringValue: String, intValue: PrimitiveInteger } """.asSmithyModel() val provider: SymbolProvider = testSymbolProvider(model) val writer = RustWriter.forModule("model") val generator = UnionGenerator(model, provider, writer, model.lookup("test#MyUnion")) generator.render() writer.compileAndTest( """ let foo = MyUnion::StringValue("foo".to_string()); let bar = MyUnion::IntValue(10); assert_eq!(foo.is_string_value(), true); assert_eq!(foo.is_int_value(), false); assert_eq!(foo.as_string_value(), Some(&"foo".to_string())); assert_eq!(foo.as_int_value(), None); assert_eq!(bar.is_string_value(), false); assert_eq!(bar.is_int_value(), true); assert_eq!(bar.as_string_value(), None); assert_eq!(bar.as_int_value(), Some(&10)); """ ) } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/UnionGenerator.kt +16 −0 Original line number Diff line number Diff line Loading @@ -11,9 +11,11 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.documentShape import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.util.toPascalCase import software.amazon.smithy.rust.codegen.util.toSnakeCase class UnionGenerator( val model: Model, Loading @@ -39,5 +41,19 @@ class UnionGenerator( write("${member.memberName.toPascalCase()}(#T),", symbolProvider.toSymbol(member)) } } writer.rustBlock("impl ${symbol.name}") { sortedMembers.forEach { member -> val memberSymbol = symbolProvider.toSymbol(member) val funcNamePart = member.memberName.toSnakeCase() val variantName = member.memberName.toPascalCase() writer.rustBlock("pub fn as_$funcNamePart(&self) -> Option<&#T>", memberSymbol) { rust("if let ${symbol.name}::$variantName(val) = &self { Some(&val) } else { None }") } writer.rustBlock("pub fn is_$funcNamePart(&self) -> bool") { rust("self.as_$funcNamePart().is_some()") } } } } }
codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/UnionGeneratorTest.kt +30 −0 Original line number Diff line number Diff line Loading @@ -40,4 +40,34 @@ class UnionGeneratorTest { ) writer.toString() shouldContain "#[non_exhaustive]" } @Test fun `generate conversion helper methods`() { val model = """ namespace test union MyUnion { stringValue: String, intValue: PrimitiveInteger } """.asSmithyModel() val provider: SymbolProvider = testSymbolProvider(model) val writer = RustWriter.forModule("model") val generator = UnionGenerator(model, provider, writer, model.lookup("test#MyUnion")) generator.render() writer.compileAndTest( """ let foo = MyUnion::StringValue("foo".to_string()); let bar = MyUnion::IntValue(10); assert_eq!(foo.is_string_value(), true); assert_eq!(foo.is_int_value(), false); assert_eq!(foo.as_string_value(), Some(&"foo".to_string())); assert_eq!(foo.as_int_value(), None); assert_eq!(bar.is_string_value(), false); assert_eq!(bar.is_int_value(), true); assert_eq!(bar.as_string_value(), None); assert_eq!(bar.as_int_value(), Some(&10)); """ ) } }