Unverified Commit 4f327bc1 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generate request body struct (#30)

* Run all tests in the same cargo workspace

Running all the tests in the same cargo workspace allows test runners to avoid recompiling dependencies every time. This allows us to drop parallel tests which seemed to be causing OOMs.

* Doc cleanups and improvements

* Add support for public and private modules

* Generate Synthetic Body structs for serialization

To enable serialization & deserialization, we generate synthetic structs representing the request body. Currently, the serialization method returns a stub, however, in a subsequent PR, serde will be used to generate actual request bodies.

* Fix doc typo

* CR feedback
parent 33705fa3
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.lang

data class RustModule(val name: String, val meta: Meta) {
    fun render(writer: RustWriter) {
        meta.render(writer)
        writer.write("mod $name;")
    }
}
+6 −2
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ fun <T : CodeWriter> T.rustBlock(header: String, vararg args: Any, block: T.() -
}

class RustWriter private constructor(private val filename: String, val namespace: String, private val commentCharacter: String = "//") :
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) {
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(namespace)) {
    companion object {
        fun forModule(module: String): RustWriter {
            return RustWriter("$module.rs", "crate::$module")
@@ -73,7 +73,11 @@ class RustWriter private constructor(private val filename: String, val namespace
        putFormatter('T', formatter)
    }

    fun safeName(prefix: String = "var"): String {
    fun module(): String? = if (filename.endsWith(".rs")) {
        filename.removeSuffix(".rs").split('/').last()
    } else null

    private fun safeName(prefix: String = "var"): String {
        n += 1
        return "${prefix}_$n"
    }
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ package software.amazon.smithy.rust.codegen.lang
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.writer.ImportContainer

class UseDeclarations(private val filename: String, private val namespace: String) : ImportContainer {
class UseDeclarations(private val namespace: String) : ImportContainer {
    private val imports: MutableSet<UseStatement> = mutableSetOf()
    fun addImport(moduleName: String, symbolName: String, alias: String = symbolName) {
        imports.add(UseStatement(moduleName, symbolName, alias))
+17 −11
Original line number Diff line number Diff line
@@ -17,7 +17,9 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.lang.Meta
import software.amazon.smithy.rust.codegen.lang.RustDependency
import software.amazon.smithy.rust.codegen.lang.RustModule
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
@@ -29,12 +31,16 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger

private val PublicModules = listOf("error", "operation", "model")
private val Modules = listOf(
    RustModule("error", Meta(public = true)),
    RustModule("operation", Meta(public = true)),
    RustModule("model", Meta(public = true)),
    RustModule("serializer", Meta(public = false))
)

class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

@@ -51,13 +57,12 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {

    init {
        val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig)
        val bootstrapProvider = RustCodegenPlugin.BaseSymbolProvider(context.model, symbolVisitorConfig)
        model = OperationNormalizer(bootstrapProvider).addOperationInputs(context.model)
        symbolProvider =
            RustCodegenPlugin.BaseSymbolProvider(model, SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        val service = settings.getService(context.model)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        protocolGenerator = generator
        model = generator.transformModel(context.model)
        val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig)
        symbolProvider = generator.symbolProvider(model, baseProvider)

        protocolConfig = ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol)
        writers = CodegenWriterDelegator(
@@ -69,6 +74,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
        httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
    }

    private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() }

    fun execute() {
        logger.info("generating Rust client...")
        val service = settings.getService(model)
@@ -82,10 +89,10 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
            )
            cargoToml.render()
        }
        writers.useFileWriter("src/lib.rs", "crate::lib") {
            // TODO: a more structured method of signaling what modules should get loaded.
            val modules = PublicModules.filter { writers.writers.containsKey("src/$it.rs") }
            LibRsGenerator(modules, it).render()
        writers.useFileWriter("src/lib.rs", "crate::lib") { writer ->
            val includedModules = writers.includedModules().toSet()
            val modules = Modules.filter { module -> includedModules.contains(module.name) }
            LibRsGenerator(modules).render(writer)
        }
        writers.flushWriters()
        try {
@@ -99,7 +106,6 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
    }

    override fun structureShape(shape: StructureShape) {
        // super.structureShape(shape)
        logger.info("generating a structure...")
        writers.useShapeWriter(shape) {
            StructureGenerator(model, symbolProvider, it, shape).render()
+14 −12
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.smithy.generators.toSnakeCase
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInput
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.utils.StringUtils

// TODO: currently, respecting integer types.
@@ -79,11 +79,14 @@ data class SymbolVisitorConfig(
val DefaultConfig =
    SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true)

data class SymbolLocation(val filename: String, val namespace: String)
data class SymbolLocation(val namespace: String) {
    val filename = "$namespace.rs"
}

val Shapes = SymbolLocation("model.rs", "model")
val Errors = SymbolLocation("error.rs", "error")
val Operations = SymbolLocation("operation.rs", "operation")
val Shapes = SymbolLocation("model")
val Errors = SymbolLocation("error")
val Operations = SymbolLocation("operation")
val Serializers = SymbolLocation("serializer")

fun Symbol.makeOptional(): Symbol {
    return if (isOptional()) {
@@ -98,9 +101,12 @@ fun Symbol.makeOptional(): Symbol {
    }
}

fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder =
    this.definitionFile("src/${symbolLocation.filename}")
        .namespace("crate::${symbolLocation.namespace}", "::")

class SymbolVisitor(
    private val model: Model,
    private val rootNamespace: String = "crate",
    private val config: SymbolVisitorConfig = DefaultConfig
) : SymbolProvider,
    ShapeVisitor<Symbol> {
@@ -115,15 +121,11 @@ class SymbolVisitor(
        return RuntimeType.Blob(config.runtimeConfig).toSymbol()
    }

    private fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder =
        this.definitionFile("src/${symbolLocation.filename}")
            .namespace("$rootNamespace::${symbolLocation.namespace}", "::")

    private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol {
        // If a field has the httpLabel trait and we are generating
        // an Input shape, then the field is _not optional_.
        val httpLabeledInput =
            container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java)
            container.hasTrait(SyntheticInputTrait::class.java) && member.hasTrait(HttpLabelTrait::class.java)
        return if (nullableIndex.isNullable(member) && !httpLabeledInput) {
            symbol.makeOptional()
        } else symbol
@@ -214,7 +216,7 @@ class SymbolVisitor(

    override fun structureShape(shape: StructureShape): Symbol {
        val isError = shape.hasTrait(ErrorTrait::class.java)
        val isInput = shape.hasTrait(SyntheticInput::class.java)
        val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
        val name = StringUtils.capitalize(shape.id.name).letIf(isError) {
            // TODO: this is should probably be a configurable mixin
            it.replace("Exception", "Error")
Loading