Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +6 −2 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.runCommand import java.util.logging.Logger Loading @@ -57,10 +58,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() { init { val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig) val service = settings.getService(context.model) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service) protocolGenerator = generator model = generator.transformModel(context.model) model = generator.transformModel(baseModel) val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig) symbolProvider = generator.symbolProvider(model, baseProvider) Loading @@ -74,6 +76,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() { httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig) } private fun baselineTransform(model: Model) = RecursiveShapeBoxer.transform(model) private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() } fun execute() { Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustBox.kt→codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustBoxTrait.kt +1 −1 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.Trait class RustBox : Trait { class RustBoxTrait : Trait { val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#box") override fun toNode(): Node = Node.objectNode() Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +1 −1 Original line number Diff line number Diff line Loading @@ -132,7 +132,7 @@ class SymbolVisitor( } private fun handleRustBoxing(symbol: Symbol, shape: Shape): Symbol { return if (shape.hasTrait(RustBox::class.java)) { return if (shape.hasTrait(RustBoxTrait::class.java)) { val rustType = RustType.Box(symbol.rustType()) with(Symbol.builder()) { rustType(rustType) Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RecursiveShapeBoxer.kt 0 → 100644 +87 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.transformers import software.amazon.smithy.codegen.core.TopologicalIndex import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait object RecursiveShapeBoxer { /** * Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait] * * When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will * iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point. * * This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so * this function may cause backward compatibility issues in certain pathological cases where a changes to recursive * structures cause different members to be boxed. We may need to address these via customizations. */ fun transform(model: Model): Model { val next = transformInner(model) return if (next == null) { model } else { transform(next) } } /** * If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model. * If [model] contains no loops, return null. */ private fun transformInner(model: Model): Model? { // Execute 1-step of the boxing algorithm in the path to reaching a fixed point // 1. Find all the shapes that are part of a cycle // 2. Find all the loops that those shapes are part of // 3. Filter out the loops that go through a layer of indirection // 3. Pick _just one_ of the remaining loops to fix // 4. Select the member shape in that loop with the earliest shape id // 5. Box it. // (External to this function) Go back to 1. val index = TopologicalIndex(model) val recursiveShapes = index.recursiveShapes val loops = recursiveShapes.map { // Get all the shapes in the closure (represented as Paths shapeId -> index.getRecursiveClosure(shapeId) }.flatMap { // flatten the connections into shapes loops -> loops.map { it.shapes } } val loopToFix = loops.firstOrNull { !containsIndirection(it) } return loopToFix?.let { loop: List<Shape> -> check(loop.isNotEmpty()) // pick the shape to box in a deterministic way val shapeToBox = loop.filterIsInstance<MemberShape>().minBy { it.id }!! ModelTransformer.create().mapShapes(model) { shape -> if (shape == shapeToBox) { shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build() } else { shape } } } } /** * Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the * need to add more Boxes */ private fun containsIndirection(loop: List<Shape>): Boolean { return loop.find { when (it) { is ListShape, is MapShape, is SetShape -> true else -> it.hasTrait(RustBoxTrait::class.java) } } != null } } codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RecursiveShapeBoxerTest.kt 0 → 100644 +75 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.transformers import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait import software.amazon.smithy.rust.codegen.util.lookup import software.amazon.smithy.rust.testutil.asSmithy import kotlin.streams.toList internal class RecursiveShapeBoxerTest { @Test fun `leave non-recursive models unchanged`() { val model = """ namespace com.example list BarList { member: Bar } structure Hello { bars: BarList } structure Bar { hello: Hello } """.asSmithy() RecursiveShapeBoxer.transform(model) shouldBe model } @Test fun `add the box trait to simple recursive shapes`() { val model = """ namespace com.example structure Recursive { RecursiveStruct: Recursive, anotherField: Boolean } """.asSmithy() val transformed = RecursiveShapeBoxer.transform(model) val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct") member.expectTrait(RustBoxTrait::class.java) } @Test fun `add the box trait to complex structures`() { val model = """ namespace com.example structure Expr { left: Atom, right: Atom } union Atom { add: Expr, sub: Expr, literal: Integer, more: SecondTree } structure SecondTree { member: Expr, otherMember: Atom, third: SecondTree } """.asSmithy() val transformed = RecursiveShapeBoxer.transform(model) val boxed = transformed.shapes().filter { it.hasTrait(RustBoxTrait::class.java) }.toList() boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf( "Atom\$add", "Atom\$sub", "SecondTree\$third", "Atom\$more" ) } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/CodegenVisitor.kt +6 −2 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.util.CommandFailed import software.amazon.smithy.rust.codegen.util.runCommand import java.util.logging.Logger Loading @@ -57,10 +58,11 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() { init { val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig) val service = settings.getService(context.model) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service) protocolGenerator = generator model = generator.transformModel(context.model) model = generator.transformModel(baseModel) val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig) symbolProvider = generator.symbolProvider(model, baseProvider) Loading @@ -74,6 +76,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() { httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig) } private fun baselineTransform(model: Model) = RecursiveShapeBoxer.transform(model) private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() } fun execute() { Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustBox.kt→codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RustBoxTrait.kt +1 −1 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.Trait class RustBox : Trait { class RustBoxTrait : Trait { val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#box") override fun toNode(): Node = Node.objectNode() Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +1 −1 Original line number Diff line number Diff line Loading @@ -132,7 +132,7 @@ class SymbolVisitor( } private fun handleRustBoxing(symbol: Symbol, shape: Shape): Symbol { return if (shape.hasTrait(RustBox::class.java)) { return if (shape.hasTrait(RustBoxTrait::class.java)) { val rustType = RustType.Box(symbol.rustType()) with(Symbol.builder()) { rustType(rustType) Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RecursiveShapeBoxer.kt 0 → 100644 +87 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.transformers import software.amazon.smithy.codegen.core.TopologicalIndex import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait object RecursiveShapeBoxer { /** * Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait] * * When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will * iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point. * * This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so * this function may cause backward compatibility issues in certain pathological cases where a changes to recursive * structures cause different members to be boxed. We may need to address these via customizations. */ fun transform(model: Model): Model { val next = transformInner(model) return if (next == null) { model } else { transform(next) } } /** * If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model. * If [model] contains no loops, return null. */ private fun transformInner(model: Model): Model? { // Execute 1-step of the boxing algorithm in the path to reaching a fixed point // 1. Find all the shapes that are part of a cycle // 2. Find all the loops that those shapes are part of // 3. Filter out the loops that go through a layer of indirection // 3. Pick _just one_ of the remaining loops to fix // 4. Select the member shape in that loop with the earliest shape id // 5. Box it. // (External to this function) Go back to 1. val index = TopologicalIndex(model) val recursiveShapes = index.recursiveShapes val loops = recursiveShapes.map { // Get all the shapes in the closure (represented as Paths shapeId -> index.getRecursiveClosure(shapeId) }.flatMap { // flatten the connections into shapes loops -> loops.map { it.shapes } } val loopToFix = loops.firstOrNull { !containsIndirection(it) } return loopToFix?.let { loop: List<Shape> -> check(loop.isNotEmpty()) // pick the shape to box in a deterministic way val shapeToBox = loop.filterIsInstance<MemberShape>().minBy { it.id }!! ModelTransformer.create().mapShapes(model) { shape -> if (shape == shapeToBox) { shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build() } else { shape } } } } /** * Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the * need to add more Boxes */ private fun containsIndirection(loop: List<Shape>): Boolean { return loop.find { when (it) { is ListShape, is MapShape, is SetShape -> true else -> it.hasTrait(RustBoxTrait::class.java) } } != null } }
codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/RecursiveShapeBoxerTest.kt 0 → 100644 +75 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.smithy.transformers import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.rust.codegen.smithy.RustBoxTrait import software.amazon.smithy.rust.codegen.util.lookup import software.amazon.smithy.rust.testutil.asSmithy import kotlin.streams.toList internal class RecursiveShapeBoxerTest { @Test fun `leave non-recursive models unchanged`() { val model = """ namespace com.example list BarList { member: Bar } structure Hello { bars: BarList } structure Bar { hello: Hello } """.asSmithy() RecursiveShapeBoxer.transform(model) shouldBe model } @Test fun `add the box trait to simple recursive shapes`() { val model = """ namespace com.example structure Recursive { RecursiveStruct: Recursive, anotherField: Boolean } """.asSmithy() val transformed = RecursiveShapeBoxer.transform(model) val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct") member.expectTrait(RustBoxTrait::class.java) } @Test fun `add the box trait to complex structures`() { val model = """ namespace com.example structure Expr { left: Atom, right: Atom } union Atom { add: Expr, sub: Expr, literal: Integer, more: SecondTree } structure SecondTree { member: Expr, otherMember: Atom, third: SecondTree } """.asSmithy() val transformed = RecursiveShapeBoxer.transform(model) val boxed = transformed.shapes().filter { it.hasTrait(RustBoxTrait::class.java) }.toList() boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf( "Atom\$add", "Atom\$sub", "SecondTree\$third", "Atom\$more" ) } }