Unverified Commit 30502ebc authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Refactor Operation shape builders (#68)

* Refactor structure builders out of StructureGenerator

* Refactor Operation Shape

This commit builds on the builder refactoring to introduce `Operation`s as the top level object that users interact with. This has a number of improvements:

1. The operations page in the documentation is now clean and includes exactly the operations that are available.
2. Builders for Input shapes now actually return an `Operation`. In a follow-up commit, the `build()` method on input shapes will be
updated to accept a service configuration object.
3. The split-impl block for operation shapes has been removed.
4. Docs for operation shapes are now the documentation for the operation and not for the input
5. This sets us up for having a top level operation onto which we can attach a `from_response(..)` method in a follow-up PR to implement response parsing.
parent 63fcc795
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -26,9 +26,10 @@ import java.util.function.BiFunction
fun <T : CodeWriter> T.withBlock(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    vararg args: Any,
    block: T.() -> Unit
): T {
    return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block)
    return conditionalBlock(textBeforeNewLine, textAfterNewLine, conditional = true, block = block, args = *args)
}

/**
@@ -47,10 +48,11 @@ fun <T : CodeWriter> T.conditionalBlock(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    conditional: Boolean = true,
    vararg args: Any,
    block: T.() -> Unit
): T {
    if (conditional) {
        openBlock(textBeforeNewLine)
        openBlock(textBeforeNewLine, *args)
    }

    block(this)
+16 −6
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@
package software.amazon.smithy.rust.codegen.smithy

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.writer.CodegenWriterDelegator
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.neighbor.Walker
@@ -26,12 +25,15 @@ import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.HttpProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.LibRsGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ModelBuilderGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
@@ -40,14 +42,14 @@ import java.util.logging.Logger
/**
 * Allowlist of modules that will be exposed publicly in generated crates
 */
private val PublicModules = setOf("error", "operation", "model")
private val PublicModules = setOf("error", "operation", "model", "output", "input")

class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

    private val logger = Logger.getLogger(javaClass.name)
    private val settings = RustSettings.from(context.model, context.settings)

    private val symbolProvider: SymbolProvider
    private val symbolProvider: RustSymbolProvider
    private val writers: CodegenWriterDelegator<RustWriter>
    private val fileManifest = context.fileManifest
    private val model: Model
@@ -77,7 +79,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

    private fun baselineTransform(model: Model) = RecursiveShapeBoxer.transform(model)

    private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() }
    private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> =
        this.writers.values.mapNotNull { it.module() }

    fun execute() {
        logger.info("generating Rust client...")
@@ -120,8 +123,15 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

    override fun structureShape(shape: StructureShape) {
        logger.info("generating a structure...")
        writers.useShapeWriter(shape) {
            StructureGenerator(model, symbolProvider, it, shape).render()
        writers.useShapeWriter(shape) { writer ->
            StructureGenerator(model, symbolProvider, writer, shape).render()
            if (!shape.hasTrait(SyntheticInputTrait::class.java)) {
                val builderGenerator = ModelBuilderGenerator(protocolConfig.model, protocolConfig.symbolProvider, shape)
                builderGenerator.render(writer)
                writer.implBlock(shape, symbolProvider) {
                    builderGenerator.renderConvenienceMethod(this)
                }
            }
        }
    }

+11 −9
Original line number Diff line number Diff line
@@ -38,9 +38,9 @@ import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.smithy.generators.toSnakeCase
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import software.amazon.smithy.utils.StringUtils

// TODO: currently, respecting integer types.
@@ -71,10 +71,12 @@ data class SymbolLocation(val namespace: String) {
    val filename = "$namespace.rs"
}

val Shapes = SymbolLocation("model")
val Models = SymbolLocation("model")
val Errors = SymbolLocation("error")
val Operations = SymbolLocation("operation")
val Serializers = SymbolLocation("serializer")
val Inputs = SymbolLocation("input")
val Outputs = SymbolLocation("output")

fun Symbol.makeOptional(): Symbol {
    return if (isOptional()) {
@@ -150,7 +152,7 @@ class SymbolVisitor(
    override fun doubleShape(shape: DoubleShape): Symbol = simpleShape(shape)
    override fun stringShape(shape: StringShape): Symbol {
        return if (shape.hasTrait(EnumTrait::class.java)) {
            symbolBuilder(shape, RustType.Opaque(shape.id.name)).locatedIn(Shapes).build()
            symbolBuilder(shape, RustType.Opaque(shape.id.name)).locatedIn(Models).build()
        } else {
            simpleShape(shape)
        }
@@ -208,8 +210,8 @@ class SymbolVisitor(

    override fun structureShape(shape: StructureShape): Symbol {
        val isError = shape.hasTrait(ErrorTrait::class.java)
        val isIoShape =
            shape.hasTrait(SyntheticInputTrait::class.java) || shape.hasTrait(SyntheticOutputTrait::class.java)
        val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
        val isOutput = shape.hasTrait(SyntheticOutputTrait::class.java)
        val name = StringUtils.capitalize(shape.id.name).letIf(isError && config.codegenConfig.renameExceptions) {
            // TODO: Do we want to do this?
            // https://github.com/awslabs/smithy-rs/issues/77
@@ -218,15 +220,15 @@ class SymbolVisitor(
        val builder = symbolBuilder(shape, RustType.Opaque(name))
        return when {
            isError -> builder.locatedIn(Errors)
            // Input shapes live with their Operations
            isIoShape -> builder.locatedIn(Operations)
            else -> builder.locatedIn(Shapes)
            isInput -> builder.locatedIn(Inputs)
            isOutput -> builder.locatedIn(Outputs)
            else -> builder.locatedIn(Models)
        }.build()
    }

    override fun unionShape(shape: UnionShape): Symbol {
        val name = StringUtils.capitalize(shape.id.name)
        val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Shapes)
        val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(Models)

        return builder.build()
    }
+179 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.conditionalBlock
import software.amazon.smithy.rust.codegen.lang.docs
import software.amazon.smithy.rust.codegen.lang.documentShape
import software.amazon.smithy.rust.codegen.lang.render
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.stripOuter
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.makeOptional
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase

fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): RuntimeType {
    val symbol = symbolProvider.toSymbol(this)
    return RuntimeType("Builder", null, "${symbol.namespace}::${symbol.name.toSnakeCase()}")
}

class ModelBuilderGenerator(
    model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val shape: StructureShape
) :
    BuilderGenerator(model, symbolProvider, shape) {
    override fun buildFn(implBlockWriter: RustWriter) {
        val fallibleBuilder = StructureGenerator.fallibleBuilder(shape, symbolProvider)
        val returnType = when (fallibleBuilder) {
            true -> "Result<#T, String>"
            false -> "#T"
        }
        val outputSymbol = symbolProvider.toSymbol(shape)
        implBlockWriter.docs("Consumes the builder and constructs a #D", outputSymbol)
        implBlockWriter.rustBlock("pub fn build(self) -> $returnType", outputSymbol) {
            conditionalBlock("Ok(", ")", conditional = fallibleBuilder) {
                // If a wrapper is specified, use the `::new` associated function to construct the wrapper
                coreBuilder(this)
            }
        }
    }
}

class OperationInputBuilderGenerator(
    model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val shape: OperationShape
) : BuilderGenerator(model, symbolProvider, shape.inputShape(model)) {
    override fun buildFn(implBlockWriter: RustWriter) {
        val fallibleBuilder = StructureGenerator.fallibleBuilder(shape.inputShape(model), symbolProvider)
        val returnType = when (fallibleBuilder) {
            true -> "Result<#T, String>"
            false -> "#T"
        }
        val outputSymbol = symbolProvider.toSymbol(shape)

        implBlockWriter.docs("Consumes the builder and constructs a #D", outputSymbol)
        implBlockWriter.rustBlock("pub fn build(self) -> $returnType", outputSymbol) {
            conditionalBlock("Ok(", ")", conditional = fallibleBuilder) {
                // If a wrapper is specified, use the `::new` associated function to construct the wrapper
                withBlock("#T::new(", ")", outputSymbol) {
                    coreBuilder(this)
                }
            }
        }
    }
}

abstract class BuilderGenerator(
    val model: Model,
    private val symbolProvider: RustSymbolProvider,
    private val shape: StructureShape
) {
    private val members: List<MemberShape> = shape.allMembers.values.toList()
    private val structureSymbol = symbolProvider.toSymbol(shape)
    fun render(writer: RustWriter) {
        val symbol = symbolProvider.toSymbol(shape)
        // TODO: figure out exactly what docs we want on a the builder module
        writer.docs("See #D", symbol)
        // check(writer.namespace == shape.builderSymbol(symbolProvider).namespace)
        val segments = shape.builderSymbol(symbolProvider).namespace.split("::")
        writer.withModule(segments.last()) {
            renderBuilder(this)
        }
    }

    fun renderConvenienceMethod(implBlock: RustWriter) {
        val builderSymbol = shape.builderSymbol(symbolProvider)
        implBlock.docs("Creates a new builder-style object to manufacture #D", structureSymbol)
        implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) {
            write("#T::default()", builderSymbol)
        }
    }

    private fun renderBuilder(writer: RustWriter) {
        val builderName = "Builder"

        val symbol = structureSymbol
        writer.docs("A builder for #D", symbol)
        writer.write("##[non_exhaustive]")
        writer.write("##[derive(Debug, Clone, Default)]")
        writer.rustBlock("pub struct $builderName") {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                val memberSymbol = symbolProvider.toSymbol(member).makeOptional()
                // TODO: should the builder members be public?
                write("$memberName: #T,", memberSymbol)
            }
        }

        fun builderConverter(coreType: RustType) = when (coreType) {
            is RustType.String,
            is RustType.Box -> "inp.into()"
            else -> "inp"
        }

        writer.rustBlock("impl $builderName") {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                val memberSymbol = symbolProvider.toSymbol(member)
                val outerType = memberSymbol.rustType()
                val coreType = outerType.stripOuter<RustType.Option>()
                val signature = when (coreType) {
                    is RustType.String,
                    is RustType.Box -> "(mut self, inp: impl Into<${coreType.render(true)}>) -> Self"
                    else -> "(mut self, inp: ${coreType.render(true)}) -> Self"
                }
                writer.documentShape(member, model)
                writer.rustBlock("pub fn $memberName$signature") {
                    write("self.$memberName = Some(${builderConverter(coreType)});")
                    write("self")
                }
            }

            buildFn(this)
        }
    }

    abstract fun buildFn(implBlockWriter: RustWriter)

    /**
     * The core builder of the inner type. If the structure requires a fallible builder, this may use `?` to return
     * errors
     * ```rust
     * SomeStruct {
     *    field: builder.field,
     *    field2: builder.field2,
     *    field3: builder.field3.unwrap_or_default()
     *    field4: builder.field4.ok_or("field4 is required when building SomeStruct")?
     * }
     */
    protected fun coreBuilder(writer: RustWriter) {
        writer.rustBlock("#T", structureSymbol) {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                val memberSymbol = symbolProvider.toSymbol(member)
                val errorWhenMissing = "$memberName is required when building ${structureSymbol.name}"
                val modifier = when {
                    !memberSymbol.isOptional() && memberSymbol.canUseDefault() -> ".unwrap_or_default()"
                    !memberSymbol.isOptional() -> ".ok_or(${errorWhenMissing.dq()})?"
                    else -> ""
                }
                write("$memberName: self.$memberName$modifier,")
            }
        }
    }
}
+1 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.util.doubleQuote
import software.amazon.smithy.rust.codegen.util.toPascalCase
import java.lang.IllegalStateException

class EnumGenerator(
Loading