Unverified Commit a6ce8395 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Support recursive shapes (#32)

* Support recursive shapes

* Fix doc typo

* Update comments some more
parent 4f327bc1
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger
@@ -57,10 +58,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

    init {
        val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig)
        val service = settings.getService(context.model)
        val baseModel = baselineTransform(context.model)
        val service = settings.getService(baseModel)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        protocolGenerator = generator
        model = generator.transformModel(context.model)
        model = generator.transformModel(baseModel)
        val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig)
        symbolProvider = generator.symbolProvider(model, baseProvider)

@@ -74,6 +76,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
        httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
    }

    private fun baselineTransform(model: Model) = RecursiveShapeBoxer.transform(model)

    private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() }

    fun execute() {
+1 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.Trait

class RustBox : Trait {
class RustBoxTrait : Trait {
    val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#box")
    override fun toNode(): Node = Node.objectNode()

+1 −1
Original line number Diff line number Diff line
@@ -132,7 +132,7 @@ class SymbolVisitor(
    }

    private fun handleRustBoxing(symbol: Symbol, shape: Shape): Symbol {
        return if (shape.hasTrait(RustBox::class.java)) {
        return if (shape.hasTrait(RustBoxTrait::class.java)) {
            val rustType = RustType.Box(symbol.rustType())
            with(Symbol.builder()) {
                rustType(rustType)
+87 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.transformers

import software.amazon.smithy.codegen.core.TopologicalIndex
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.SetShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait

object RecursiveShapeBoxer {
    /**
     * Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait]
     *
     * When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will
     * iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point.
     *
     * This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so
     * this function may cause backward compatibility issues in certain pathological cases where a changes to recursive
     * structures cause different members to be boxed. We may need to address these via customizations.
     */
    fun transform(model: Model): Model {
        val next = transformInner(model)
        return if (next == null) {
            model
        } else {
            transform(next)
        }
    }

    /**
     * If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model.
     * If [model] contains no loops, return null.
     */
    private fun transformInner(model: Model): Model? {
        // Execute 1-step of the boxing algorithm in the path to reaching a fixed point
        // 1. Find all the shapes that are part of a cycle
        // 2. Find all the loops that those shapes are part of
        // 3. Filter out the loops that go through a layer of indirection
        // 3. Pick _just one_ of the remaining loops to fix
        // 4. Select the member shape in that loop with the earliest shape id
        // 5. Box it.
        // (External to this function) Go back to 1.
        val index = TopologicalIndex(model)
        val recursiveShapes = index.recursiveShapes
        val loops = recursiveShapes.map {
            // Get all the shapes in the closure (represented as Paths
            shapeId ->
            index.getRecursiveClosure(shapeId)
        }.flatMap {
            // flatten the connections into shapes
            loops ->
            loops.map { it.shapes }
        }
        val loopToFix = loops.firstOrNull { !containsIndirection(it) }

        return loopToFix?.let { loop: List<Shape> ->
            check(loop.isNotEmpty())
            // pick the shape to box in a deterministic way
            val shapeToBox = loop.filterIsInstance<MemberShape>().minBy { it.id }!!
            ModelTransformer.create().mapShapes(model) { shape ->
                if (shape == shapeToBox) {
                    shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build()
                } else {
                    shape
                }
            }
        }
    }

    /**
     * Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the
     * need to add more Boxes
     */
    private fun containsIndirection(loop: List<Shape>): Boolean {
        return loop.find {
            when (it) {
                is ListShape,
                is MapShape,
                is SetShape -> true
                else -> it.hasTrait(RustBoxTrait::class.java)
            }
        } != null
    }
}
+75 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.transformers

import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait
import software.amazon.smithy.rust.codegen.util.lookup
import software.amazon.smithy.rust.testutil.asSmithy
import kotlin.streams.toList

internal class RecursiveShapeBoxerTest {
    @Test
    fun `leave non-recursive models unchanged`() {
        val model = """
            namespace com.example
            list BarList {
                member: Bar
            }
            structure Hello {
                bars: BarList
            }

            structure Bar {
                hello: Hello
            }
        """.asSmithy()
        RecursiveShapeBoxer.transform(model) shouldBe model
    }

    @Test
    fun `add the box trait to simple recursive shapes`() {
        val model = """
        namespace com.example
        structure Recursive {
            RecursiveStruct: Recursive,
            anotherField: Boolean
        }
        """.asSmithy()
        val transformed = RecursiveShapeBoxer.transform(model)
        val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct")
        member.expectTrait(RustBoxTrait::class.java)
    }

    @Test
    fun `add the box trait to complex structures`() {
        val model = """
       namespace com.example
       structure Expr {
            left: Atom,
            right: Atom
       }

       union Atom {
            add: Expr,
            sub: Expr,
            literal: Integer,
            more: SecondTree
       }

       structure SecondTree {
            member: Expr,
            otherMember: Atom,
            third: SecondTree
       }
       """.asSmithy()
        val transformed = RecursiveShapeBoxer.transform(model)
        val boxed = transformed.shapes().filter { it.hasTrait(RustBoxTrait::class.java) }.toList()
        boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf(
            "Atom\$add",
            "Atom\$sub",
            "SecondTree\$third",
            "Atom\$more"
        )
    }
}
Loading