Unverified Commit 6e772b9d authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generate Builders for Structures (#3)

A builder object is generated for Structure shapes. The builder will be fallible (return `Result<T, String>`) if the structure has required members without defaults.

If the structure has no required members, the `build()` method directly returns the constructed object. This required a number of refinements to our module handling as the builders are namespaced to 1-module-per-shape.
parent 4557f6a6
Loading
Loading
Loading
Loading
+57 −13
Original line number Diff line number Diff line
@@ -18,10 +18,19 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.utils.CodeWriter

fun CodeWriter.withBlock(textBeforeNewLine: String, textAfterNewLine: String, block: CodeWriter.() -> Unit): CodeWriter {
fun CodeWriter.withBlock(
    textBeforeNewLine: String,
    textAfterNewLine: String,
    conditional: Boolean = true,
    block: CodeWriter.() -> Unit
): CodeWriter {
    if (conditional) {
        openBlock(textBeforeNewLine)
    }
    block(this)
    if (conditional) {
        closeBlock(textAfterNewLine)
    }
    return this
}

@@ -35,9 +44,30 @@ fun <T : CodeWriter> T.rustBlock(header: String, vararg args: Any, block: T.() -
    return this
}

class RustWriter(filename: String, private val namespace: String, private val commentCharacter: String = "//") : CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) {
class RustWriter private constructor(private val filename: String, val namespace: String, private val commentCharacter: String = "//") :
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) {
    companion object {
        fun forModule(module: String): RustWriter {
            return RustWriter("$module.rs", "crate::$module")
        }

        val Factory: CodegenWriterFactory<RustWriter> =
            CodegenWriterFactory<RustWriter> { filename, namespace ->
                when {
                    filename.endsWith(".toml") -> RustWriter(filename, namespace, "#")
                    else -> RustWriter(filename, namespace)
                }
            }
    }
    init {
        if (filename.endsWith(".rs")) {
            require(namespace.startsWith("crate")) { "We can only write into files in the crate (got $namespace)" }
        }
    }

    private val formatter = RustSymbolFormatter()
    private var n = 0

    init {
        putFormatter('T', formatter)
    }
@@ -47,6 +77,23 @@ class RustWriter(filename: String, private val namespace: String, private val co
        return "${prefix}_$n"
    }

    /**
     * Create an inline module.
     * [header] should be the declaration of the module, eg. `pub mod Hello`.
     *
     * The returned writer will inject any local imports into the module as needed.
     */
    fun withModule(moduleName: String, visibility: String = "pub", moduleWriter: RustWriter.() -> Unit) {
        // In Rust, modules must specify their own imports—they don't have access to the parent scope.
        // To easily handle this, create a new inner writer to collect imports, then dump it
        // into an inline module.
        val innerWriter = RustWriter(this.filename, "${this.namespace}::$moduleName")
        moduleWriter(innerWriter)
        rustBlock("$visibility mod $moduleName") {
            write(innerWriter.toString())
        }
    }

    // TODO: refactor both of these methods & add a parent method to for_each across any field type
    // generically
    fun OptionForEach(member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) {
@@ -99,7 +146,12 @@ class RustWriter(filename: String, private val namespace: String, private val co
                is RuntimeType -> {
                    t.dependency?.also { addDependency(it) }
                    // for now, use the fully qualified type name
                    "::${t.namespace}::${t.name}"
                    val prefix = if (t.namespace.startsWith("crate")) {
                        ""
                    } else {
                        "::"
                    }
                    "$prefix${t.namespace}::${t.name}"
                }
                is Symbol -> {
                    if (t.namespace != namespace) {
@@ -111,12 +163,4 @@ class RustWriter(filename: String, private val namespace: String, private val co
            }
        }
    }

    companion object {
        val Factory: CodegenWriterFactory<RustWriter> =
            CodegenWriterFactory<RustWriter> { filename, namespace -> when {
                filename.endsWith(".toml") -> RustWriter(filename, namespace, "#")
                else -> RustWriter(filename, namespace)
            } }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -65,7 +65,7 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
            )
            cargoToml.render()
        }
        writers.useFileWriter("src/lib.rs") {
        writers.useFileWriter("src/lib.rs", "crate::lib") {
            // TODO: a more structured method of signaling what modules should get loaded.
            val modules = PublicModules.filter { writers.writers.containsKey("src/$it.rs") }
            LibRsGenerator(modules, it).render()
+40 −13
Original line number Diff line number Diff line
@@ -69,10 +69,15 @@ fun Symbol.referenceClosure(): List<Symbol> {
    return listOf(this) + referencedSymbols.flatMap { it.referenceClosure() }
}

data class SymbolVisitorConfig(val runtimeConfig: RuntimeConfig, val handleOptionality: Boolean = true, val handleRustBoxing: Boolean = true)
data class SymbolVisitorConfig(
    val runtimeConfig: RuntimeConfig,
    val handleOptionality: Boolean = true,
    val handleRustBoxing: Boolean = true
)

// TODO: consider if this is better handled as a wrapper
val DefaultConfig = SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true)
val DefaultConfig =
    SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true)

data class SymbolLocation(val filename: String, val namespace: String)

@@ -84,6 +89,19 @@ val Shapes = SymbolLocation("model.rs", "model")
val Errors = SymbolLocation("error.rs", "error")
val Operations = SymbolLocation("operation.rs", "operation")

fun Symbol.makeOptional(): Symbol {
    return if (isOptional()) {
        this
    } else {
        val rustType = RustType.Option(this.rustType())
        Symbol.builder().rustType(rustType)
            .rustType(rustType)
            .addReference(this)
            .name(rustType.name)
            .build()
    }
}

class SymbolVisitor(
    private val model: Model,
    private val rootNamespace: String = "crate",
@@ -102,16 +120,12 @@ class SymbolVisitor(
    }

    private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol {
        val httpLabeledInput = container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java)
        // If a field has the httpLabel trait and we are generating
        // an Input shape, then the field is _not optional_.
        val httpLabeledInput =
            container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java)
        return if (nullableIndex.isNullable(member) && !httpLabeledInput) {
            with(Symbol.builder()) {
                val rustType = RustType.Option(symbol.rustType())
                rustType(rustType)
                addReference(symbol)
                name(rustType.name)
                putProperty(SHAPE_KEY, member)
                build()
            }
            symbol.makeOptional()
        } else symbol
    }

@@ -128,7 +142,7 @@ class SymbolVisitor(
    }

    private fun simpleShape(shape: SimpleShape): Symbol {
        return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build()
        return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).canUseDefault().build()
    }

    override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape)
@@ -239,7 +253,7 @@ class SymbolVisitor(
        return builder.rustType(rustType)
            .name(rustType.name)
            // Every symbol that actually gets defined somewhere should set a definition file
            // If we ever generate a `thisisabug.rs`, we messed something up
            // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation
            .definitionFile("thisisabug.rs")
    }
}
@@ -247,11 +261,24 @@ class SymbolVisitor(
// TODO(chore): Move this to a useful place
private const val RUST_TYPE_KEY = "rusttype"
private const val SHAPE_KEY = "shape"
private const val CAN_USE_DEFAULT = "canusedefault"

fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder {
    return this.putProperty(RUST_TYPE_KEY, rustType)
}

fun Symbol.Builder.canUseDefault(value: Boolean = true): Symbol.Builder {
    return this.putProperty(CAN_USE_DEFAULT, value)
}

/**
 * True when it is valid to use the default/0 value for [this] symbol during construction.
 */
fun Symbol.canUseDefault(): Boolean = this.getProperty(CAN_USE_DEFAULT, Boolean::class.javaObjectType).orElse(false)

/**
 * True when [this] is will be represented by Option<T> in Rust
 */
fun Symbol.isOptional(): Boolean = when (this.rustType()) {
    is RustType.Option -> true
    else -> false
+111 −8
Original line number Diff line number Diff line
@@ -10,18 +10,31 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.render
import software.amazon.smithy.rust.codegen.lang.rustBlock
import software.amazon.smithy.rust.codegen.lang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.makeOptional
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.utils.CaseUtils

// TODO(maybe): extract struct generation from Smithy shapes to support generating body objects
// TODO: generate builders; 1d
// TODO: generate documentation
class StructureGenerator(
    val model: Model,
    private val symbolProvider: SymbolProvider,
    private val writer: RustWriter,
    private val shape: StructureShape
    private val shape: StructureShape,
    private val renderBuilder: Boolean = true
) {
    private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }
    private val members: List<MemberShape> = shape.allMembers.values.toList()
    private val structureSymbol = symbolProvider.toSymbol(shape)
    private val builderSymbol = RuntimeType("Builder", null, "${structureSymbol.namespace}::${structureSymbol.name.toSnakeCase()}")
    fun render() {
        renderStructure()
        val errorTrait = shape.getTrait(ErrorTrait::class.java)
@@ -29,6 +42,12 @@ class StructureGenerator(
            val errorGenerator = ErrorGenerator(model, symbolProvider, writer, shape, it)
            errorGenerator.render()
        }
        if (renderBuilder) {
            val symbol = symbolProvider.toSymbol(shape)
            writer.withModule(symbol.name.toSnakeCase()) {
                renderBuilder(this)
            }
        }
    }

    private fun renderStructure() {
@@ -36,11 +55,95 @@ class StructureGenerator(
        // TODO(maybe): Pull derive info from the symbol so that the symbol provider can alter things as necessary; 4h
        writer.write("#[non_exhaustive]")
        writer.write("#[derive(Debug, PartialEq, Clone)]")
        val blockWriter = writer.openBlock("pub struct ${symbol.name} {")
        sortedMembers.forEach { member ->
        writer.rustBlock("pub struct ${symbol.name}") {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                write("pub $memberName: \$T,", symbolProvider.toSymbol(member))
            }
        }

        if (renderBuilder) {
            writer.rustBlock("impl ${symbol.name}") {
                rustBlock("pub fn builder() -> \$T", builderSymbol) {
                    write("\$T::default()", builderSymbol)
                }
            }
        }
    }

    private fun renderBuilder(writer: RustWriter) {
        // Eventually, I want to do a fancier module layout:
        // model/some_model.rs [contains builder and impl for a single model] struct SomeModel, struct Builder
        // model/mod.rs [contains pub use for each model to bring it into top level scope]
        // users will do models::SomeModel, models::SomeModel::builder()
        val builderName = "Builder"
        writer.write("#[non_exhaustive]")
        writer.write("#[derive(Debug, Clone, Default)]")
        writer.rustBlock("pub struct $builderName") {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                val memberSymbol = symbolProvider.toSymbol(member).makeOptional()
                // TODO: should the builder members be public?
                write("$memberName: \$T,", memberSymbol)
            }
        }

        fun builderConverter(rustType: RustType) = when (rustType) {
            is RustType.String -> "inp.into()"
            else -> "inp"
        }

        writer.rustBlock("impl $builderName") {
            members.forEach { member ->
                val memberName = symbolProvider.toMemberName(member)
                // All fields in the builder are optional
                val memberSymbol = symbolProvider.toSymbol(member)
                val coreType = memberSymbol.rustType().let {
                    when (it) {
                        is RustType.Option -> it.value
                        else -> it
                    }
                }
                val signature = when (coreType) {
                    is RustType.String -> "<T: Into<String>>(mut self, inp: T) -> Self"
                    else -> "(mut self, inp: ${coreType.render()}) -> Self"
                }
                writer.rustBlock("pub fn $memberName$signature") {
                    write("self.$memberName = Some(${builderConverter(coreType)});")
                    write("self")
                }
            }

            val fallible = members.map { symbolProvider.toSymbol(it) }.any {
                // If any members are not optional && we can't use a default, we need to
                // generate a fallible builder
                !it.isOptional() && !it.canUseDefault()
            }

            val returnType = when (fallible) {
                true -> "Result<\$T, String>"
                false -> "\$T"
            }

            writer.rustBlock("pub fn build(self) -> $returnType", structureSymbol) {
                withBlock("Ok(", ")", conditional = fallible) {
                    rustBlock("\$T", structureSymbol) {
                        members.forEach { member ->
                            val memberName = symbolProvider.toMemberName(member)
            blockWriter.write("pub $memberName: \$T,", symbolProvider.toSymbol(member)) }
        blockWriter.closeBlock("}")
                            val memberSymbol = symbolProvider.toSymbol(member)
                            val errorWhenMissing = "$memberName is required when building ${structureSymbol.name}"
                            val modifier = when {
                                !memberSymbol.isOptional() && memberSymbol.canUseDefault() -> ".unwrap_or_default()"
                                !memberSymbol.isOptional() -> ".ok_or(${errorWhenMissing.dq()})?"
                                else -> ""
                            }
                            write("$memberName: self.$memberName$modifier,")
                        }
                    }
                }
            }
        }
    }
}

+3 −2
Original line number Diff line number Diff line
@@ -23,8 +23,9 @@ fun String.runCommand(workdir: Path? = null): String? {

    proc.waitFor(60, TimeUnit.MINUTES)
    if (proc.exitValue() != 0) {
        val output = proc.errorStream.bufferedReader().readText()
        throw CommandFailed("Command Failed\n$output")
        val stdErr = proc.errorStream.bufferedReader().readText()
        val stdOut = proc.inputStream.bufferedReader().readText()
        throw CommandFailed("Command Failed\n$stdErr\n$stdOut")
    }
    return proc.inputStream.bufferedReader().readText()
}
Loading