Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt +20 −0 Original line number Diff line number Diff line Loading @@ -21,6 +21,26 @@ import software.amazon.smithy.utils.CodeWriter import java.util.function.BiFunction fun <T : CodeWriter> T.withBlock( textBeforeNewLine: String, textAfterNewLine: String, block: T.() -> Unit ): T { return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block) } /** * Write a block to the writer. * If [conditional] is true, the [textBeforeNewLine], followed by [block], followed by [textAfterNewLine] * If [conditional] is false, only [block] is written. * This enables conditionally wrapping a block in a prefix/suffix, eg. * * ``` * writer.withBlock("Some(", ")", conditional = symbol.isOptional()) { * write("symbolValue") * } * ``` */ fun <T : CodeWriter> T.conditionalBlock( textBeforeNewLine: String, textAfterNewLine: String, conditional: Boolean = true, Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +2 −6 Original line number Diff line number Diff line Loading @@ -27,11 +27,7 @@ class HttpProtocolTestGenerator( // TODO: remove these once Smithy publishes fixes. // These tests are not even attempted to be compiled val DisableTests = setOf( "RestJsonListsSerializeNull", "AwsJson11MapsSerializeNullValues", "AwsJson11ListsSerializeNull", "RestJsonSerializesNullMapValues", // This test is flake because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37 // This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37 "AwsJson11Enums" ) Loading Loading @@ -104,7 +100,7 @@ class HttpProtocolTestGenerator( } testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { writeInline("let input =") instantiator.render(httpRequestTestCase.params, inputShape, this) instantiator.render(this, inputShape, httpRequestTestCase.params) write(";") write("let http_request = input.build_http_request().body(()).unwrap();") with(httpRequestTestCase) { Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +35 −17 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.NullNode import software.amazon.smithy.model.node.NumberNode import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.node.StringNode Loading @@ -23,6 +24,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.lang.RustType import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.conditionalBlock import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig Loading @@ -42,7 +44,7 @@ class Instantiator( private val runtimeConfig: RuntimeConfig ) { fun render(arg: Node, shape: Shape, writer: RustWriter) { fun render(writer: RustWriter, shape: Shape, arg: Node) { when (shape) { // Compound Shapes is StructureShape -> renderStructure(writer, shape, arg as ObjectNode) Loading @@ -53,6 +55,9 @@ class Instantiator( is MapShape -> renderMap(writer, shape, arg as ObjectNode) is SetShape -> renderSet(writer, shape, arg as ArrayNode) // Members, supporting potentially optional members is MemberShape -> renderMember(writer, shape, arg) // Wrapped Shapes is TimestampShape -> writer.write( "\$T::from_epoch_seconds(${(arg as NumberNode).value})", Loading @@ -77,15 +82,37 @@ class Instantiator( } } /** * If the shape is optional: `Some(inner)` or `None` * otherwise: `inner` */ private fun renderMember( writer: RustWriter, shape: MemberShape, arg: Node ) { val target = model.expectShape(shape.target) val symbol = symbolProvider.toSymbol(shape) if (arg is NullNode) { check( symbol.isOptional() ) { "A null node was provided for $shape but the symbol was not optional. This is invalid input data." } writer.write("None") } else { writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) { render(this, target, arg) } } } private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode) { if (symbolProvider.toSymbol(shape).rustType() is RustType.HashSet) { if (!data.isEmpty) { writer.rustBlock("") { write("let mut ret = \$T::new();", RuntimeType.HashSet) val valueShape = shape.member.let { model.expectShape(it.target) } data.forEach { v -> withBlock("ret.insert(", ");") { render(v, valueShape, this) renderMember(this, shape.member, v) } } write("ret") Loading Loading @@ -115,10 +142,9 @@ class Instantiator( if (data.members.isNotEmpty()) { writer.rustBlock("") { write("let mut ret = \$T::new();", RuntimeType.HashMap) val valueShape = shape.value.let { model.expectShape(it.target) } data.members.forEach { (k, v) -> withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") { render(v, valueShape, this) renderMember(this, shape.value, v) } } write("ret") Loading Loading @@ -148,7 +174,7 @@ class Instantiator( writer.write("\$T::${memberName.toPascalCase()}", unionSymbol) // unions should specify exactly one member writer.withBlock("(", ")") { render(variant.value, member, this) render(this, member, variant.value) } } Loading @@ -162,17 +188,9 @@ class Instantiator( shape: CollectionShape, data: ArrayNode ) { val member = model.expectShape(shape.member.target) val memberSymbol = symbolProvider.toSymbol(shape.member) writer.withBlock("vec![", "]") { data.elements.forEach { if (it.isNullNode) { write("None") } else { withBlock("Some(", ")", conditional = memberSymbol.isOptional()) { render(it, member, this) } } data.elements.forEach { v -> renderMember(this, shape.member, v) write(",") } } Loading Loading @@ -209,7 +227,7 @@ class Instantiator( val func = symbolProvider.toMemberName(memberShape) if (!value.isNullNode) { writer.withBlock(".$func(", ")") { render(value, targetShape, this) render(this, targetShape, value) } } } Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +2 −2 Original line number Diff line number Diff line Loading @@ -12,9 +12,9 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.lang.RustType import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.conditionalBlock import software.amazon.smithy.rust.codegen.lang.render import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.canUseDefault import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata Loading Loading @@ -158,7 +158,7 @@ class StructureGenerator( } rustBlock("pub fn build(self) -> $returnType", structureSymbol) { withBlock("Ok(", ")", conditional = fallibleBuilder) { conditionalBlock("Ok(", ")", conditional = fallibleBuilder) { rustBlock("\$T", structureSymbol) { members.forEach { member -> val memberName = symbolProvider.toMemberName(member) Loading codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt +11 −7 Original line number Diff line number Diff line Loading @@ -80,7 +80,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, union, this) sut.render(this, union, data) } writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));") } Loading @@ -103,7 +103,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, structure, this) sut.render(this, structure, data) } writer.write("assert_eq!(result.bar, 10);") writer.write("assert_eq!(result.foo.unwrap(), \"hello\");") Loading @@ -129,7 +129,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { withBlock("let result = ", ";") { sut.render(data, structure, this) sut.render(this, structure, data) } rust( """ Loading Loading @@ -160,7 +160,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#MyList"), writer) sut.render(writer, model.lookup("com.test#MyList"), data) } writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""") } Loading @@ -182,7 +182,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#MySparseList"), writer) sut.render(writer, model.lookup("com.test#MySparseList"), data) } writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""") } Loading @@ -206,7 +206,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#NestedMap"), writer) sut.render(writer, model.lookup("com.test#NestedMap"), data) } writer.write( """ Loading @@ -229,7 +229,11 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn test_blob()") { withBlock("let blob = ", ";") { sut.render(StringNode.parse("foo".dq()), BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), this) sut.render( this, BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), StringNode.parse("foo".dq()) ) } write("assert_eq!(std::str::from_utf8(blob.as_ref()).unwrap(), \"foo\");") } Loading Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/lang/RustWriter.kt +20 −0 Original line number Diff line number Diff line Loading @@ -21,6 +21,26 @@ import software.amazon.smithy.utils.CodeWriter import java.util.function.BiFunction fun <T : CodeWriter> T.withBlock( textBeforeNewLine: String, textAfterNewLine: String, block: T.() -> Unit ): T { return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block) } /** * Write a block to the writer. * If [conditional] is true, the [textBeforeNewLine], followed by [block], followed by [textAfterNewLine] * If [conditional] is false, only [block] is written. * This enables conditionally wrapping a block in a prefix/suffix, eg. * * ``` * writer.withBlock("Some(", ")", conditional = symbol.isOptional()) { * write("symbolValue") * } * ``` */ fun <T : CodeWriter> T.conditionalBlock( textBeforeNewLine: String, textAfterNewLine: String, conditional: Boolean = true, Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +2 −6 Original line number Diff line number Diff line Loading @@ -27,11 +27,7 @@ class HttpProtocolTestGenerator( // TODO: remove these once Smithy publishes fixes. // These tests are not even attempted to be compiled val DisableTests = setOf( "RestJsonListsSerializeNull", "AwsJson11MapsSerializeNullValues", "AwsJson11ListsSerializeNull", "RestJsonSerializesNullMapValues", // This test is flake because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37 // This test is flaky because of set ordering serialization https://github.com/awslabs/smithy-rs/issues/37 "AwsJson11Enums" ) Loading Loading @@ -104,7 +100,7 @@ class HttpProtocolTestGenerator( } testModuleWriter.rustBlock("fn test_${httpRequestTestCase.id.toSnakeCase()}()") { writeInline("let input =") instantiator.render(httpRequestTestCase.params, inputShape, this) instantiator.render(this, inputShape, httpRequestTestCase.params) write(";") write("let http_request = input.build_http_request().body(()).unwrap();") with(httpRequestTestCase) { Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +35 −17 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ArrayNode import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.NullNode import software.amazon.smithy.model.node.NumberNode import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.node.StringNode Loading @@ -23,6 +24,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.lang.RustType import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.conditionalBlock import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig Loading @@ -42,7 +44,7 @@ class Instantiator( private val runtimeConfig: RuntimeConfig ) { fun render(arg: Node, shape: Shape, writer: RustWriter) { fun render(writer: RustWriter, shape: Shape, arg: Node) { when (shape) { // Compound Shapes is StructureShape -> renderStructure(writer, shape, arg as ObjectNode) Loading @@ -53,6 +55,9 @@ class Instantiator( is MapShape -> renderMap(writer, shape, arg as ObjectNode) is SetShape -> renderSet(writer, shape, arg as ArrayNode) // Members, supporting potentially optional members is MemberShape -> renderMember(writer, shape, arg) // Wrapped Shapes is TimestampShape -> writer.write( "\$T::from_epoch_seconds(${(arg as NumberNode).value})", Loading @@ -77,15 +82,37 @@ class Instantiator( } } /** * If the shape is optional: `Some(inner)` or `None` * otherwise: `inner` */ private fun renderMember( writer: RustWriter, shape: MemberShape, arg: Node ) { val target = model.expectShape(shape.target) val symbol = symbolProvider.toSymbol(shape) if (arg is NullNode) { check( symbol.isOptional() ) { "A null node was provided for $shape but the symbol was not optional. This is invalid input data." } writer.write("None") } else { writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) { render(this, target, arg) } } } private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode) { if (symbolProvider.toSymbol(shape).rustType() is RustType.HashSet) { if (!data.isEmpty) { writer.rustBlock("") { write("let mut ret = \$T::new();", RuntimeType.HashSet) val valueShape = shape.member.let { model.expectShape(it.target) } data.forEach { v -> withBlock("ret.insert(", ");") { render(v, valueShape, this) renderMember(this, shape.member, v) } } write("ret") Loading Loading @@ -115,10 +142,9 @@ class Instantiator( if (data.members.isNotEmpty()) { writer.rustBlock("") { write("let mut ret = \$T::new();", RuntimeType.HashMap) val valueShape = shape.value.let { model.expectShape(it.target) } data.members.forEach { (k, v) -> withBlock("ret.insert(${k.value.dq()}.to_string(),", ");") { render(v, valueShape, this) renderMember(this, shape.value, v) } } write("ret") Loading Loading @@ -148,7 +174,7 @@ class Instantiator( writer.write("\$T::${memberName.toPascalCase()}", unionSymbol) // unions should specify exactly one member writer.withBlock("(", ")") { render(variant.value, member, this) render(this, member, variant.value) } } Loading @@ -162,17 +188,9 @@ class Instantiator( shape: CollectionShape, data: ArrayNode ) { val member = model.expectShape(shape.member.target) val memberSymbol = symbolProvider.toSymbol(shape.member) writer.withBlock("vec![", "]") { data.elements.forEach { if (it.isNullNode) { write("None") } else { withBlock("Some(", ")", conditional = memberSymbol.isOptional()) { render(it, member, this) } } data.elements.forEach { v -> renderMember(this, shape.member, v) write(",") } } Loading Loading @@ -209,7 +227,7 @@ class Instantiator( val func = symbolProvider.toMemberName(memberShape) if (!value.isNullNode) { writer.withBlock(".$func(", ")") { render(value, targetShape, this) render(this, targetShape, value) } } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +2 −2 Original line number Diff line number Diff line Loading @@ -12,9 +12,9 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.lang.RustType import software.amazon.smithy.rust.codegen.lang.RustWriter import software.amazon.smithy.rust.codegen.lang.conditionalBlock import software.amazon.smithy.rust.codegen.lang.render import software.amazon.smithy.rust.codegen.lang.rustBlock import software.amazon.smithy.rust.codegen.lang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.canUseDefault import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata Loading Loading @@ -158,7 +158,7 @@ class StructureGenerator( } rustBlock("pub fn build(self) -> $returnType", structureSymbol) { withBlock("Ok(", ")", conditional = fallibleBuilder) { conditionalBlock("Ok(", ")", conditional = fallibleBuilder) { rustBlock("\$T", structureSymbol) { members.forEach { member -> val memberName = symbolProvider.toMemberName(member) Loading
codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/InstantiatorTest.kt +11 −7 Original line number Diff line number Diff line Loading @@ -80,7 +80,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, union, this) sut.render(this, union, data) } writer.write("assert_eq!(result, MyUnion::StringVariant(\"ok!\".to_string()));") } Loading @@ -103,7 +103,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, structure, this) sut.render(this, structure, data) } writer.write("assert_eq!(result.bar, 10);") writer.write("assert_eq!(result.foo.unwrap(), \"hello\");") Loading @@ -129,7 +129,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { withBlock("let result = ", ";") { sut.render(data, structure, this) sut.render(this, structure, data) } rust( """ Loading Loading @@ -160,7 +160,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#MyList"), writer) sut.render(writer, model.lookup("com.test#MyList"), data) } writer.write("""assert_eq!(result, vec!["bar".to_string(), "foo".to_string()]);""") } Loading @@ -182,7 +182,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#MySparseList"), writer) sut.render(writer, model.lookup("com.test#MySparseList"), data) } writer.write("""assert_eq!(result, vec![Some("bar".to_string()), Some("foo".to_string()), None]);""") } Loading @@ -206,7 +206,7 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn inst()") { writer.withBlock("let result = ", ";") { sut.render(data, model.lookup("com.test#NestedMap"), writer) sut.render(writer, model.lookup("com.test#NestedMap"), data) } writer.write( """ Loading @@ -229,7 +229,11 @@ class InstantiatorTest { writer.write("#[test]") writer.rustBlock("fn test_blob()") { withBlock("let blob = ", ";") { sut.render(StringNode.parse("foo".dq()), BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), this) sut.render( this, BlobShape.builder().id(ShapeId.from("com.example#Blob")).build(), StringNode.parse("foo".dq()) ) } write("assert_eq!(std::str::from_utf8(blob.as_ref()).unwrap(), \"foo\");") } Loading