From a5f1653c0fc9f010b9026f5ea52c4cacbecee4e2 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:39:07 +0000 Subject: [PATCH] Allow custom writables in Instantiator (#3104) `Instantiator` currently renders only static data. This PR allows it to customize this behavior to allow non-static data (e.g. data taken from other structs). ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --------- Signed-off-by: Daniele Ahmed --- .../core/smithy/generators/Instantiator.kt | 244 ++++++++++-------- .../smithy/generators/ServerInstantiator.kt | 20 +- .../generators/ServerInstantiatorTest.kt | 92 +++++++ 3 files changed, 244 insertions(+), 112 deletions(-) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt index 8a36d8ecb..795a6b5a1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/Instantiator.kt @@ -91,6 +91,7 @@ open class Instantiator( private val defaultsForRequiredFields: Boolean = false, private val customizations: List = listOf(), private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER, + private val customWritable: CustomWritable = NoCustomWritable(), ) { data class Ctx( // The `http` crate requires that headers be lowercase, but Smithy protocol tests @@ -122,28 +123,54 @@ open class Instantiator( fun doesSetterTakeInOption(memberShape: MemberShape): Boolean } - fun generate(shape: Shape, data: Node, headers: Map = mapOf(), ctx: Ctx = Ctx()) = writable { - render(this, shape, data, headers, ctx) + /** + * Customize how each shape is rendered, instead of relying on static `Node` data. + */ + interface CustomWritable { + // Return `null` to rely on the default behavior, which uses the static `Node` data. + fun generate(shape: Shape): Writable? } - fun render(writer: RustWriter, shape: Shape, data: Node, headers: Map = mapOf(), ctx: Ctx = Ctx()) { - when (shape) { - // Compound Shapes - is StructureShape -> renderStructure(writer, shape, data as ObjectNode, headers, ctx) - is UnionShape -> renderUnion(writer, shape, data as ObjectNode, ctx) - - // Collections - is ListShape -> renderList(writer, shape, data as ArrayNode, ctx) - is MapShape -> renderMap(writer, shape, data as ObjectNode, ctx) - is SetShape -> renderSet(writer, shape, data as ArrayNode, ctx) + class NoCustomWritable : CustomWritable { + override fun generate(shape: Shape): Writable? = null + } - // Members, supporting potentially optional members - is MemberShape -> renderMember(writer, shape, data, ctx) + fun generate(shape: Shape, data: Node, headers: Map = mapOf(), ctx: Ctx = Ctx()) = writable { + render(this, shape, data, headers, ctx) + } - is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate(shape, data)(writer) + open fun render( + writer: RustWriter, + shape: Shape, + data: Node, + headers: Map = mapOf(), + ctx: Ctx = Ctx(), + ) { + customWritable.generate(shape) + ?.let { it(writer) } + ?: run { + when (shape) { + // Compound Shapes + is StructureShape -> renderStructure(writer, shape, data as ObjectNode, headers, ctx) + is UnionShape -> renderUnion(writer, shape, data as ObjectNode, ctx) + + // Collections + is ListShape -> renderList(writer, shape, data as ArrayNode, ctx) + is MapShape -> renderMap(writer, shape, data as ObjectNode, ctx) + is SetShape -> renderSet(writer, shape, data as ArrayNode, ctx) + + // Members, supporting potentially optional members + is MemberShape -> renderMember(writer, shape, data, ctx) + + is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate( + shape, + data, + customWritable, + )(writer) - else -> writer.writeWithNoFormatting("todo!() /* $shape $data */") - } + else -> writer.writeWithNoFormatting("todo!() /* $shape $data */") + } + } } /** @@ -153,45 +180,49 @@ open class Instantiator( private fun renderMember(writer: RustWriter, memberShape: MemberShape, data: Node, ctx: Ctx) { val targetShape = model.expectShape(memberShape.target) val symbol = symbolProvider.toSymbol(memberShape) - if (data is NullNode && !targetShape.isDocumentShape) { - check(symbol.isOptional()) { - "A null node was provided for $memberShape but the symbol was not optional. This is invalid input data." - } - writer.rustTemplate("#{None}", *preludeScope) - } else { - // Structure builder setters for structure shape members _always_ take in `Option`. - // Other aggregate shapes' members are optional only when their symbol is. - writer.conditionalBlockTemplate( - "#{Some}(", - ")", - // The conditions are not commutative: note client builders always take in `Option`. - conditional = symbol.isOptional() || - ( - model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption( - memberShape, - ) - ), - *preludeScope, - ) { - writer.conditionalBlockTemplate( - "#{Box}::new(", - ")", - conditional = symbol.rustType().stripOuter() is RustType.Box, - *preludeScope, - ) { - render( - this, - targetShape, - data, - mapOf(), - ctx.copy() - .letIf(memberShape.hasTrait()) { - it.copy(lowercaseMapKeys = true) - }, - ) + customWritable.generate(memberShape) + ?.let { it(writer) } + ?: run { + if (data is NullNode && !targetShape.isDocumentShape) { + check(symbol.isOptional()) { + "A null node was provided for $memberShape but the symbol was not optional. This is invalid input data." + } + writer.rustTemplate("#{None}", *preludeScope) + } else { + // Structure builder setters for structure shape members _always_ take in `Option`. + // Other aggregate shapes' members are optional only when their symbol is. + writer.conditionalBlockTemplate( + "#{Some}(", + ")", + // The conditions are not commutative: note client builders always take in `Option`. + conditional = symbol.isOptional() || + ( + model.expectShape(memberShape.container) is StructureShape && builderKindBehavior.doesSetterTakeInOption( + memberShape, + ) + ), + *preludeScope, + ) { + writer.conditionalBlockTemplate( + "#{Box}::new(", + ")", + conditional = symbol.rustType().stripOuter() is RustType.Box, + *preludeScope, + ) { + render( + this, + targetShape, + data, + mapOf(), + ctx.copy() + .letIf(memberShape.hasTrait()) { + it.copy(lowercaseMapKeys = true) + }, + ) + } + } } } - } } private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) = @@ -291,6 +322,7 @@ open class Instantiator( writer.withBlockTemplate("#{T} {", "}", "T" to symbolProvider.toSymbol(shape)) { renderStructureMembers(writer, shape, data, headers, ctx) } + InstantiatorConstructPattern.BUILDER -> { writer.rust("#T::builder()", symbolProvider.toSymbol(shape)) @@ -321,6 +353,7 @@ open class Instantiator( renderMember(this, memberShape, value, ctx) } } + InstantiatorConstructPattern.BUILDER -> { val setterName = builderKindBehavior.setterName(memberShape) writer.withBlock(".$setterName(", ")") { @@ -398,67 +431,72 @@ open class Instantiator( } class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) { - fun instantiate(shape: SimpleShape, data: Node): Writable = writable { - when (shape) { - // Simple Shapes - is TimestampShape -> { - val node = (data as NumberNode) - val num = BigDecimal(node.toString()) - val wholePart = num.toInt() - val fractionalPart = num.remainder(BigDecimal.ONE) - rust( - "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", - RuntimeType.dateTime(runtimeConfig), - ) - } - - /** - * ```rust - * Blob::new("arg") - * ``` - */ - is BlobShape -> if (shape.hasTrait()) { - rust( - "#T::from_static(b${(data as StringNode).value.dq()})", - RuntimeType.byteStream(runtimeConfig), - ) - } else { - rust( - "#T::new(${(data as StringNode).value.dq()})", - RuntimeType.blob(runtimeConfig), - ) - } + fun instantiate( + shape: SimpleShape, + data: Node, + customWritable: Instantiator.CustomWritable = Instantiator.NoCustomWritable(), + ): Writable = + customWritable.generate(shape) ?: writable { + when (shape) { + // Simple Shapes + is TimestampShape -> { + val node = (data as NumberNode) + val num = BigDecimal(node.toString()) + val wholePart = num.toInt() + val fractionalPart = num.remainder(BigDecimal.ONE) + rust( + "#T::from_fractional_secs($wholePart, ${fractionalPart}_f64)", + RuntimeType.dateTime(runtimeConfig), + ) + } - is StringShape -> renderString(shape, data as StringNode)(this) - is NumberShape -> when (data) { - is StringNode -> { - val numberSymbol = symbolProvider.toSymbol(shape) - // support Smithy custom values, such as Infinity + /** + * ```rust + * Blob::new("arg") + * ``` + */ + is BlobShape -> if (shape.hasTrait()) { rust( - """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", - numberSymbol, - RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + "#T::from_static(b${(data as StringNode).value.dq()})", + RuntimeType.byteStream(runtimeConfig), + ) + } else { + rust( + "#T::new(${(data as StringNode).value.dq()})", + RuntimeType.blob(runtimeConfig), ) } - is NumberNode -> write(data.value) - } + is StringShape -> renderString(shape, data as StringNode)(this) + is NumberShape -> when (data) { + is StringNode -> { + val numberSymbol = symbolProvider.toSymbol(shape) + // support Smithy custom values, such as Infinity + rust( + """<#T as #T>::parse_smithy_primitive(${data.value.dq()}).expect("invalid string for number")""", + numberSymbol, + RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Parse"), + ) + } - is BooleanShape -> rust(data.asBooleanNode().get().toString()) - is DocumentShape -> rustBlock("") { - val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() - rustTemplate( - """ + is NumberNode -> write(data.value) + } + + is BooleanShape -> rust(data.asBooleanNode().get().toString()) + is DocumentShape -> rustBlock("") { + val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType() + rustTemplate( + """ let json_bytes = br##"${Node.prettyPrintJson(data)}"##; let mut tokens = #{json_token_iter}(json_bytes).peekable(); #{expect_document}(&mut tokens).expect("well formed json") """, - "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), - "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), - ) + "expect_document" to smithyJson.resolve("deserialize::token::expect_document"), + "json_token_iter" to smithyJson.resolve("deserialize::json_token_iter"), + ) + } } } - } private fun renderString(shape: StringShape, arg: StringNode): Writable = { val data = escape(arg.value).dq() diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 77a96c816..fade44886 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -68,15 +68,17 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat codegenContext.symbolProvider.toSymbol(memberShape).isOptional() } -class ServerInstantiator(codegenContext: CodegenContext) : Instantiator( - codegenContext.symbolProvider, - codegenContext.model, - codegenContext.runtimeConfig, - ServerBuilderKindBehavior(codegenContext), - defaultsForRequiredFields = true, - customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), - constructPattern = InstantiatorConstructPattern.DIRECT, -) +class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable()) : + Instantiator( + codegenContext.symbolProvider, + codegenContext.model, + codegenContext.runtimeConfig, + ServerBuilderKindBehavior(codegenContext), + defaultsForRequiredFields = true, + customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)), + constructPattern = InstantiatorConstructPattern.DIRECT, + customWritable = customWritable, + ) class ServerBuilderInstantiator( private val symbolProvider: RustSymbolProvider, diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index cfb4d00bb..5a6f97915 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -8,18 +8,25 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Test import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup @@ -248,4 +255,89 @@ class ServerInstantiatorTest { sut.render(writer, shape, data) writer.toString() shouldNotContain "builder()" } + + @Test + fun `uses writable for shapes`() { + val nestedStruct = model.lookup("com.test#NestedStruct") + val inner = model.lookup("com.test#Inner") + + val project = TestWorkspace.testProject(model) + nestedStruct.renderWithModelBuilder(model, symbolProvider, project) + inner.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(nestedStruct) { + val nestedUnion = model.lookup("com.test#NestedUnion") + UnionGenerator(model, symbolProvider, this, nestedUnion).render() + + unitTest("writable_for_shapes") { + val sut = ServerInstantiator( + codegenContext, + customWritable = object : Instantiator.CustomWritable { + override fun generate(shape: Shape): Writable? = + if (model.lookup("com.test#NestedStruct\$num") == shape) { + writable("40 + 2") + } else { + null + } + }, + ) + val data = Node.parse("""{ "str": "hello", "num": 1 }""") + withBlock("let result = ", ";") { + sut.render(this, model.lookup("com.test#NestedStruct"), data as ObjectNode) + } + rust( + """ + assert_eq!(result.num, 42); + assert_eq!(result.str, "hello"); + """, + ) + } + + unitTest("writable_for_nested_inner_members") { + val map = model.lookup("com.test#Inner\$map") + val sut = ServerInstantiator( + codegenContext, + customWritable = object : Instantiator.CustomWritable { + private var n: Int = 0 + override fun generate(shape: Shape): Writable? = + if (shape != map) { + null + } else if (n != 2) { + n += 1 + null + } else { + n += 1 + writable("None") + } + }, + ) + val data = Node.parse( + """ + { + "map": { + "k1": { + "map": { + "k2": { + "map": { + "never": {} + } + } + } + } + } + } + """, + ) + + withBlock("let result = ", ";") { + sut.render(this, inner, data as ObjectNode) + } + rust( + """ + assert_eq!(result.map().unwrap().get("k1").unwrap().map().unwrap().get("k2").unwrap().map(), None); + """, + ) + } + } + project.compileAndTest(runClippy = true) + } } -- GitLab