Unverified Commit a5f1653c authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Allow custom writables in Instantiator (#3104)



`Instantiator` currently renders only static data.
This PR allows it to customize this behavior to allow non-static data
(e.g. data taken from other structs).


----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
parent 7fb78bb6
Loading
Loading
Loading
Loading
+141 −103
Original line number Diff line number Diff line
@@ -91,6 +91,7 @@ open class Instantiator(
    private val defaultsForRequiredFields: Boolean = false,
    private val customizations: List<InstantiatorCustomization> = listOf(),
    private val constructPattern: InstantiatorConstructPattern = InstantiatorConstructPattern.BUILDER,
    private val customWritable: CustomWritable = NoCustomWritable(),
) {
    data class Ctx(
        // The `http` crate requires that headers be lowercase, but Smithy protocol tests
@@ -122,11 +123,32 @@ open class Instantiator(
        fun doesSetterTakeInOption(memberShape: MemberShape): Boolean
    }

    /**
     * Customize how each shape is rendered, instead of relying on static `Node` data.
     */
    interface CustomWritable {
        // Return `null` to rely on the default behavior, which uses the static `Node` data.
        fun generate(shape: Shape): Writable?
    }

    class NoCustomWritable : CustomWritable {
        override fun generate(shape: Shape): Writable? = null
    }

    fun generate(shape: Shape, data: Node, headers: Map<String, String> = mapOf(), ctx: Ctx = Ctx()) = writable {
        render(this, shape, data, headers, ctx)
    }

    fun render(writer: RustWriter, shape: Shape, data: Node, headers: Map<String, String> = mapOf(), ctx: Ctx = Ctx()) {
    open fun render(
        writer: RustWriter,
        shape: Shape,
        data: Node,
        headers: Map<String, String> = mapOf(),
        ctx: Ctx = Ctx(),
    ) {
        customWritable.generate(shape)
            ?.let { it(writer) }
            ?: run {
                when (shape) {
                    // Compound Shapes
                    is StructureShape -> renderStructure(writer, shape, data as ObjectNode, headers, ctx)
@@ -140,11 +162,16 @@ open class Instantiator(
                    // Members, supporting potentially optional members
                    is MemberShape -> renderMember(writer, shape, data, ctx)

            is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate(shape, data)(writer)
                    is SimpleShape -> PrimitiveInstantiator(runtimeConfig, symbolProvider).instantiate(
                        shape,
                        data,
                        customWritable,
                    )(writer)

                    else -> writer.writeWithNoFormatting("todo!() /* $shape $data */")
                }
            }
    }

    /**
     * If the shape is optional: `Some(inner)` or `None`.
@@ -153,6 +180,9 @@ open class Instantiator(
    private fun renderMember(writer: RustWriter, memberShape: MemberShape, data: Node, ctx: Ctx) {
        val targetShape = model.expectShape(memberShape.target)
        val symbol = symbolProvider.toSymbol(memberShape)
        customWritable.generate(memberShape)
            ?.let { it(writer) }
            ?: run {
                if (data is NullNode && !targetShape.isDocumentShape) {
                    check(symbol.isOptional()) {
                        "A null node was provided for $memberShape but the symbol was not optional. This is invalid input data."
@@ -193,6 +223,7 @@ open class Instantiator(
                    }
                }
            }
    }

    private fun renderSet(writer: RustWriter, shape: SetShape, data: ArrayNode, ctx: Ctx) =
        renderList(writer, shape, data, ctx)
@@ -291,6 +322,7 @@ open class Instantiator(
                writer.withBlockTemplate("#{T} {", "}", "T" to symbolProvider.toSymbol(shape)) {
                    renderStructureMembers(writer, shape, data, headers, ctx)
                }

            InstantiatorConstructPattern.BUILDER -> {
                writer.rust("#T::builder()", symbolProvider.toSymbol(shape))

@@ -321,6 +353,7 @@ open class Instantiator(
                        renderMember(this, memberShape, value, ctx)
                    }
                }

                InstantiatorConstructPattern.BUILDER -> {
                    val setterName = builderKindBehavior.setterName(memberShape)
                    writer.withBlock(".$setterName(", ")") {
@@ -398,7 +431,12 @@ open class Instantiator(
}

class PrimitiveInstantiator(private val runtimeConfig: RuntimeConfig, private val symbolProvider: SymbolProvider) {
    fun instantiate(shape: SimpleShape, data: Node): Writable = writable {
    fun instantiate(
        shape: SimpleShape,
        data: Node,
        customWritable: Instantiator.CustomWritable = Instantiator.NoCustomWritable(),
    ): Writable =
        customWritable.generate(shape) ?: writable {
            when (shape) {
                // Simple Shapes
                is TimestampShape -> {
+11 −9
Original line number Diff line number Diff line
@@ -68,7 +68,8 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat
        codegenContext.symbolProvider.toSymbol(memberShape).isOptional()
}

class ServerInstantiator(codegenContext: CodegenContext) : Instantiator(
class ServerInstantiator(codegenContext: CodegenContext, customWritable: CustomWritable = NoCustomWritable()) :
    Instantiator(
        codegenContext.symbolProvider,
        codegenContext.model,
        codegenContext.runtimeConfig,
@@ -76,6 +77,7 @@ class ServerInstantiator(codegenContext: CodegenContext) : Instantiator(
        defaultsForRequiredFields = true,
        customizations = listOf(ServerAfterInstantiatingValueConstrainItIfNecessary(codegenContext)),
        constructPattern = InstantiatorConstructPattern.DIRECT,
        customWritable = customWritable,
    )

class ServerBuilderInstantiator(
+92 −0
Original line number Diff line number Diff line
@@ -8,18 +8,25 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators
import io.kotest.matchers.string.shouldNotContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.lookup
@@ -248,4 +255,89 @@ class ServerInstantiatorTest {
        sut.render(writer, shape, data)
        writer.toString() shouldNotContain "builder()"
    }

    @Test
    fun `uses writable for shapes`() {
        val nestedStruct = model.lookup<StructureShape>("com.test#NestedStruct")
        val inner = model.lookup<StructureShape>("com.test#Inner")

        val project = TestWorkspace.testProject(model)
        nestedStruct.renderWithModelBuilder(model, symbolProvider, project)
        inner.renderWithModelBuilder(model, symbolProvider, project)
        project.moduleFor(nestedStruct) {
            val nestedUnion = model.lookup<UnionShape>("com.test#NestedUnion")
            UnionGenerator(model, symbolProvider, this, nestedUnion).render()

            unitTest("writable_for_shapes") {
                val sut = ServerInstantiator(
                    codegenContext,
                    customWritable = object : Instantiator.CustomWritable {
                        override fun generate(shape: Shape): Writable? =
                            if (model.lookup<MemberShape>("com.test#NestedStruct\$num") == shape) {
                                writable("40 + 2")
                            } else {
                                null
                            }
                    },
                )
                val data = Node.parse("""{ "str": "hello", "num": 1 }""")
                withBlock("let result = ", ";") {
                    sut.render(this, model.lookup("com.test#NestedStruct"), data as ObjectNode)
                }
                rust(
                    """
                    assert_eq!(result.num, 42);
                    assert_eq!(result.str, "hello");
                    """,
                )
            }

            unitTest("writable_for_nested_inner_members") {
                val map = model.lookup<MemberShape>("com.test#Inner\$map")
                val sut = ServerInstantiator(
                    codegenContext,
                    customWritable = object : Instantiator.CustomWritable {
                        private var n: Int = 0
                        override fun generate(shape: Shape): Writable? =
                            if (shape != map) {
                                null
                            } else if (n != 2) {
                                n += 1
                                null
                            } else {
                                n += 1
                                writable("None")
                            }
                    },
                )
                val data = Node.parse(
                    """
                    {
                        "map": {
                            "k1": {
                                "map": {
                                    "k2": {
                                        "map": {
                                            "never": {}
                                        }
                                    }
                                }
                            }
                        }
                    }
                    """,
                )

                withBlock("let result = ", ";") {
                    sut.render(this, inner, data as ObjectNode)
                }
                rust(
                    """
                    assert_eq!(result.map().unwrap().get("k1").unwrap().map().unwrap().get("k2").unwrap().map(), None);
                    """,
                )
            }
        }
        project.compileAndTest(runClippy = true)
    }
}