Unverified Commit 7e580f64 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generate builders for recursive shapes (#33)

parent a6ce8395
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.lang

import org.intellij.lang.annotations.Language
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.writer.CodegenWriter
@@ -35,6 +36,13 @@ fun <T : CodeWriter> T.withBlock(
    return this
}

/**
 * Convenience wrapper that tells Intellij that the contents of this block are Rust
 */
fun <T : CodeWriter> T.rust(@Language("Rust", prefix = "fn foo() {", suffix = "}") contents: String, vararg args: Any) {
    this.write(contents, *args)
}

/*
 * Writes a Rust-style block, demarcated by curly braces
 */
+4 −3
Original line number Diff line number Diff line
@@ -240,10 +240,11 @@ class SymbolVisitor(
    override fun memberShape(shape: MemberShape): Symbol {
        val target = model.expectShape(shape.target)
        val targetSymbol = this.toSymbol(target)
        return targetSymbol.letIf(config.handleOptionality) {
            handleOptionality(it, shape, model.expectShape(shape.container))
        }.letIf(config.handleRustBoxing) {
        // Handle boxing first so we end up with Option<Box<_>>, not Box<Option<_>>
        return targetSymbol.letIf(config.handleRustBoxing) {
            handleRustBoxing(it, shape)
        }.letIf(config.handleOptionality) {
            handleOptionality(it, shape, model.expectShape(shape.container))
        }
    }

+9 −6
Original line number Diff line number Diff line
@@ -122,8 +122,9 @@ class StructureGenerator(
            }
        }

        fun builderConverter(rustType: RustType) = when (rustType) {
            is RustType.String -> "inp.into()"
        fun builderConverter(coreType: RustType, outerType: RustType) = when (coreType) {
            is RustType.String,
            is RustType.Box -> "inp.into()"
            else -> "inp"
        }

@@ -132,18 +133,20 @@ class StructureGenerator(
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                val memberSymbol = symbolProvider.toSymbol(member)
                val coreType = memberSymbol.rustType().let {
                val outerType = memberSymbol.rustType()
                val coreType = outerType.let {
                    when (it) {
                        is RustType.Option -> it.value
                        else -> it
                    }
                }
                val signature = when (coreType) {
                    is RustType.String -> "<T: Into<String>>(mut self, inp: T) -> Self"
                    is RustType.String -> "<Str: Into<String>>(mut self, inp: Str) -> Self"
                    is RustType.Box -> "<T>(mut self, inp: T) -> Self where T: Into<${coreType.render()}>"
                    else -> "(mut self, inp: ${coreType.render()}) -> Self"
                }
                writer.rustBlock("pub fn $memberName$signature") {
                    write("self.$memberName = Some(${builderConverter(coreType)});")
                    write("self.$memberName = Some(${builderConverter(coreType, outerType)});")
                    write("self")
                }
            }
@@ -154,7 +157,7 @@ class StructureGenerator(
                false -> "\$T"
            }

            writer.rustBlock("pub fn build(self) -> $returnType", structureSymbol) {
            rustBlock("pub fn build(self) -> $returnType", structureSymbol) {
                withBlock("Ok(", ")", conditional = fallibleBuilder) {
                    rustBlock("\$T", structureSymbol) {
                        members.forEach { member ->
+43 −1
Original line number Diff line number Diff line
@@ -8,8 +8,10 @@ import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.rust
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.testutil.TestRuntimeConfig
@@ -52,7 +54,12 @@ class InstantiatorTest {
            key: String,
            value: Inner
        }
        """.asSmithy()

        structure WithBox {
            member: WithBox,
            value: Integer
        }
        """.asSmithy().let { RecursiveShapeBoxer.transform(it) }

    private val symbolProvider = testSymbolProvider(model)
    private val runtimeConfig = TestRuntimeConfig
@@ -104,6 +111,41 @@ class InstantiatorTest {
        writer.compileAndTest()
    }

    @Test
    fun `generate builders for boxed structs`() {
        val structure = model.lookup<StructureShape>("com.test#WithBox")
        val sut = Instantiator(symbolProvider, model, runtimeConfig)
        val data = Node.parse(
            """ {
                "member": {
                    "member": { }
                }, "value": 10
            }
            """.trimIndent()
        )
        val writer = RustWriter.forModule("model")
        val structureGenerator = StructureGenerator(model, symbolProvider, writer, structure)
        structureGenerator.render()
        writer.write("#[test]")
        writer.rustBlock("fn inst()") {
            withBlock("let result = ", ";") {
                sut.render(data, structure, this)
            }
            rust(
                """
                assert_eq!(result, WithBox {
                    value: Some(10),
                    member: Some(Box::new(WithBox {
                        value: None,
                        member: Some(Box::new(WithBox { value: None, member: None })),
                    }))
                });
            """
            )
        }
        writer.compileAndTest()
    }

    @Test
    fun `generate lists`() {
        val data = Node.parse(