Unverified Commit 4f327bc1 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Generate request body struct (#30)

* Run all tests in the same cargo workspace

Running all the tests in the same cargo workspace allows test runners to avoid recompiling dependencies every time. This allows us to drop parallel tests which seemed to be causing OOMs.

* Doc cleanups and improvements

* Add support for public and private modules

* Generate Synthetic Body structs for serialization

To enable serialization & deserialization, we generate synthetic structs representing the request body. Currently, the serialization method returns a stub, however, in a subsequent PR, serde will be used to generate actual request bodies.

* Fix doc typo

* CR feedback
parent 33705fa3
Loading
Loading
Loading
Loading
+8 −0
Original line number Original line Diff line number Diff line
package software.amazon.smithy.rust.codegen.lang

data class RustModule(val name: String, val meta: Meta) {
    fun render(writer: RustWriter) {
        meta.render(writer)
        writer.write("mod $name;")
    }
}
+6 −2
Original line number Original line Diff line number Diff line
@@ -46,7 +46,7 @@ fun <T : CodeWriter> T.rustBlock(header: String, vararg args: Any, block: T.() -
}
}


class RustWriter private constructor(private val filename: String, val namespace: String, private val commentCharacter: String = "//") :
class RustWriter private constructor(private val filename: String, val namespace: String, private val commentCharacter: String = "//") :
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(filename, namespace)) {
    CodegenWriter<RustWriter, UseDeclarations>(null, UseDeclarations(namespace)) {
    companion object {
    companion object {
        fun forModule(module: String): RustWriter {
        fun forModule(module: String): RustWriter {
            return RustWriter("$module.rs", "crate::$module")
            return RustWriter("$module.rs", "crate::$module")
@@ -73,7 +73,11 @@ class RustWriter private constructor(private val filename: String, val namespace
        putFormatter('T', formatter)
        putFormatter('T', formatter)
    }
    }


    fun safeName(prefix: String = "var"): String {
    fun module(): String? = if (filename.endsWith(".rs")) {
        filename.removeSuffix(".rs").split('/').last()
    } else null

    private fun safeName(prefix: String = "var"): String {
        n += 1
        n += 1
        return "${prefix}_$n"
        return "${prefix}_$n"
    }
    }
+1 −1
Original line number Original line Diff line number Diff line
@@ -8,7 +8,7 @@ package software.amazon.smithy.rust.codegen.lang
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.writer.ImportContainer
import software.amazon.smithy.codegen.core.writer.ImportContainer


class UseDeclarations(private val filename: String, private val namespace: String) : ImportContainer {
class UseDeclarations(private val namespace: String) : ImportContainer {
    private val imports: MutableSet<UseStatement> = mutableSetOf()
    private val imports: MutableSet<UseStatement> = mutableSetOf()
    fun addImport(moduleName: String, symbolName: String, alias: String = symbolName) {
    fun addImport(moduleName: String, symbolName: String, alias: String = symbolName) {
        imports.add(UseStatement(moduleName, symbolName, alias))
        imports.add(UseStatement(moduleName, symbolName, alias))
+17 −11
Original line number Original line Diff line number Diff line
@@ -17,7 +17,9 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.lang.Meta
import software.amazon.smithy.rust.codegen.lang.RustDependency
import software.amazon.smithy.rust.codegen.lang.RustDependency
import software.amazon.smithy.rust.codegen.lang.RustModule
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.lang.RustWriter
import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.CargoTomlGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
@@ -29,12 +31,16 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ServiceGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.CommandFailed
import software.amazon.smithy.rust.codegen.util.runCommand
import software.amazon.smithy.rust.codegen.util.runCommand
import java.util.logging.Logger
import java.util.logging.Logger


private val PublicModules = listOf("error", "operation", "model")
private val Modules = listOf(
    RustModule("error", Meta(public = true)),
    RustModule("operation", Meta(public = true)),
    RustModule("model", Meta(public = true)),
    RustModule("serializer", Meta(public = false))
)


class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {


@@ -51,13 +57,12 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {


    init {
    init {
        val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig)
        val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig)
        val bootstrapProvider = RustCodegenPlugin.BaseSymbolProvider(context.model, symbolVisitorConfig)
        model = OperationNormalizer(bootstrapProvider).addOperationInputs(context.model)
        symbolProvider =
            RustCodegenPlugin.BaseSymbolProvider(model, SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        val service = settings.getService(context.model)
        val service = settings.getService(context.model)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        protocolGenerator = generator
        protocolGenerator = generator
        model = generator.transformModel(context.model)
        val baseProvider = RustCodegenPlugin.BaseSymbolProvider(model, symbolVisitorConfig)
        symbolProvider = generator.symbolProvider(model, baseProvider)


        protocolConfig = ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol)
        protocolConfig = ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol)
        writers = CodegenWriterDelegator(
        writers = CodegenWriterDelegator(
@@ -69,6 +74,8 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
        httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
        httpGenerator = protocolGenerator.buildProtocolGenerator(protocolConfig)
    }
    }


    private fun CodegenWriterDelegator<RustWriter>.includedModules(): List<String> = this.writers.values.mapNotNull { it.module() }

    fun execute() {
    fun execute() {
        logger.info("generating Rust client...")
        logger.info("generating Rust client...")
        val service = settings.getService(model)
        val service = settings.getService(model)
@@ -82,10 +89,10 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
            )
            )
            cargoToml.render()
            cargoToml.render()
        }
        }
        writers.useFileWriter("src/lib.rs", "crate::lib") {
        writers.useFileWriter("src/lib.rs", "crate::lib") { writer ->
            // TODO: a more structured method of signaling what modules should get loaded.
            val includedModules = writers.includedModules().toSet()
            val modules = PublicModules.filter { writers.writers.containsKey("src/$it.rs") }
            val modules = Modules.filter { module -> includedModules.contains(module.name) }
            LibRsGenerator(modules, it).render()
            LibRsGenerator(modules).render(writer)
        }
        }
        writers.flushWriters()
        writers.flushWriters()
        try {
        try {
@@ -99,7 +106,6 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
    }
    }


    override fun structureShape(shape: StructureShape) {
    override fun structureShape(shape: StructureShape) {
        // super.structureShape(shape)
        logger.info("generating a structure...")
        logger.info("generating a structure...")
        writers.useShapeWriter(shape) {
        writers.useShapeWriter(shape) {
            StructureGenerator(model, symbolProvider, it, shape).render()
            StructureGenerator(model, symbolProvider, it, shape).render()
+14 −12
Original line number Original line Diff line number Diff line
@@ -39,7 +39,7 @@ import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.lang.RustType
import software.amazon.smithy.rust.codegen.smithy.generators.toSnakeCase
import software.amazon.smithy.rust.codegen.smithy.generators.toSnakeCase
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInput
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.utils.StringUtils
import software.amazon.smithy.utils.StringUtils


// TODO: currently, respecting integer types.
// TODO: currently, respecting integer types.
@@ -79,11 +79,14 @@ data class SymbolVisitorConfig(
val DefaultConfig =
val DefaultConfig =
    SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true)
    SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true)


data class SymbolLocation(val filename: String, val namespace: String)
data class SymbolLocation(val namespace: String) {
    val filename = "$namespace.rs"
}


val Shapes = SymbolLocation("model.rs", "model")
val Shapes = SymbolLocation("model")
val Errors = SymbolLocation("error.rs", "error")
val Errors = SymbolLocation("error")
val Operations = SymbolLocation("operation.rs", "operation")
val Operations = SymbolLocation("operation")
val Serializers = SymbolLocation("serializer")


fun Symbol.makeOptional(): Symbol {
fun Symbol.makeOptional(): Symbol {
    return if (isOptional()) {
    return if (isOptional()) {
@@ -98,9 +101,12 @@ fun Symbol.makeOptional(): Symbol {
    }
    }
}
}


fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder =
    this.definitionFile("src/${symbolLocation.filename}")
        .namespace("crate::${symbolLocation.namespace}", "::")

class SymbolVisitor(
class SymbolVisitor(
    private val model: Model,
    private val model: Model,
    private val rootNamespace: String = "crate",
    private val config: SymbolVisitorConfig = DefaultConfig
    private val config: SymbolVisitorConfig = DefaultConfig
) : SymbolProvider,
) : SymbolProvider,
    ShapeVisitor<Symbol> {
    ShapeVisitor<Symbol> {
@@ -115,15 +121,11 @@ class SymbolVisitor(
        return RuntimeType.Blob(config.runtimeConfig).toSymbol()
        return RuntimeType.Blob(config.runtimeConfig).toSymbol()
    }
    }


    private fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder =
        this.definitionFile("src/${symbolLocation.filename}")
            .namespace("$rootNamespace::${symbolLocation.namespace}", "::")

    private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol {
    private fun handleOptionality(symbol: Symbol, member: MemberShape, container: Shape): Symbol {
        // If a field has the httpLabel trait and we are generating
        // If a field has the httpLabel trait and we are generating
        // an Input shape, then the field is _not optional_.
        // an Input shape, then the field is _not optional_.
        val httpLabeledInput =
        val httpLabeledInput =
            container.hasTrait(SyntheticInput::class.java) && member.hasTrait(HttpLabelTrait::class.java)
            container.hasTrait(SyntheticInputTrait::class.java) && member.hasTrait(HttpLabelTrait::class.java)
        return if (nullableIndex.isNullable(member) && !httpLabeledInput) {
        return if (nullableIndex.isNullable(member) && !httpLabeledInput) {
            symbol.makeOptional()
            symbol.makeOptional()
        } else symbol
        } else symbol
@@ -214,7 +216,7 @@ class SymbolVisitor(


    override fun structureShape(shape: StructureShape): Symbol {
    override fun structureShape(shape: StructureShape): Symbol {
        val isError = shape.hasTrait(ErrorTrait::class.java)
        val isError = shape.hasTrait(ErrorTrait::class.java)
        val isInput = shape.hasTrait(SyntheticInput::class.java)
        val isInput = shape.hasTrait(SyntheticInputTrait::class.java)
        val name = StringUtils.capitalize(shape.id.name).letIf(isError) {
        val name = StringUtils.capitalize(shape.id.name).letIf(isError) {
            // TODO: this is should probably be a configurable mixin
            // TODO: this is should probably be a configurable mixin
            it.replace("Exception", "Error")
            it.replace("Exception", "Error")
Loading