Unverified Commit dfe42b73 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Refactor generation to use unified Meta (#26)

* Refactor generation to use unified Meta

A `Meta` object is introduced that is attached during Symbol generation. This configures the visibility and attributes that the generators should use when generating structs/unions/etc. This is a refactoring to setup the next set of PRs which will generate serializers for models via the same mechanism.

* Cleanup SymbolVisitor cruft
parent 56ee9c52
Loading
Loading
Loading
Loading
+81 −0
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@

package software.amazon.smithy.rust.codegen.lang

import software.amazon.smithy.rust.codegen.smithy.RuntimeType

/**
 * A hierarchy of types handled by Smithy codegen
 */
@@ -73,3 +75,82 @@ fun RustType.render(): String = when (this) {
    is RustType.Box -> "${this.name}<${this.value.render()}>"
    is RustType.Opaque -> this.name
}

/**
 * Meta information about a Rust construction (field, struct, or enum)
 */
data class Meta(val derives: Derives = Derives.Empty, val additionalAttributes: List<Attribute> = listOf(), val public: Boolean) {
    fun attributes(): List<Attribute> = additionalAttributes + derives
    fun renderAttributes(writer: RustWriter): Meta {
        attributes().forEach {
            it.render(writer)
        }
        return this
    }

    fun renderVisibility(writer: RustWriter): Meta {
        if (public) {
            writer.writeInline("pub ")
        }
        return this
    }

    fun render(writer: RustWriter) {
        renderAttributes(writer)
        renderVisibility(writer)
    }
}

/**
 * [Attributes](https://doc.rust-lang.org/reference/attributes.html) are general free form metadata
 * that are interpreted by the compiler.
 *
 * For example:
 * ```rust
 *
 * #[derive(Clone, PartialEq, Serialize)] // <-- this is an attribute
 * #[serde(serialize_with = "abc")] // <-- this is an attribute
 * struct Abc {
 *   a: i64
 * }
 */
sealed class Attribute {
    abstract fun render(writer: RustWriter)

    companion object {
        /**
         * [non_exhaustive](https://doc.rust-lang.org/reference/attributes/type_system.html#the-non_exhaustive-attribute)
         * indicates that more fields may be added in the future
         */
        val NonExhaustive = Custom("non_exhaustive")
    }
}

data class Derives(val derives: Set<RuntimeType>) : Attribute() {
    override fun render(writer: RustWriter) {
        if (derives.isEmpty()) {
            return
        }
        writer.writeInline("#[derive(")
        derives.sortedBy { it.name }.forEach { derive ->
            writer.writeInline("\$T, ", derive)
        }
        writer.write(")]")
    }

    companion object {
        val Empty = Derives(setOf())
    }
}

data class Custom(val annot: String, val symbols: List<RuntimeType> = listOf()) : Attribute() {
    override fun render(writer: RustWriter) {
        writer.writeInline("#[")
        writer.writeInline(annot)
        writer.write("]")

        symbols.forEach {
            writer.addDependency(it.dependency)
        }
    }
}
+3 −3
Original line number Diff line number Diff line
@@ -80,17 +80,17 @@ class RustWriter private constructor(private val filename: String, val namespace

    /**
     * Create an inline module.
     * [header] should be the declaration of the module, eg. `pub mod Hello`.
     *
     * The returned writer will inject any local imports into the module as needed.
     */
    fun withModule(moduleName: String, visibility: String = "pub", moduleWriter: RustWriter.() -> Unit) {
    fun withModule(moduleName: String, meta: Meta = Meta(public = true), moduleWriter: RustWriter.() -> Unit) {
        // In Rust, modules must specify their own imports—they don't have access to the parent scope.
        // To easily handle this, create a new inner writer to collect imports, then dump it
        // into an inline module.
        val innerWriter = RustWriter(this.filename, "${this.namespace}::$moduleName")
        moduleWriter(innerWriter)
        rustBlock("$visibility mod $moduleName") {
        meta.render(this)
        rustBlock("mod $moduleName") {
            write(innerWriter.toString())
        }
        innerWriter.dependencies.forEach { addDependency(it) }
+6 −5
Original line number Diff line number Diff line
@@ -48,16 +48,17 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Unit>() {
    private val protocolConfig: ProtocolConfig
    private val protocolGenerator: ProtocolGeneratorFactory<HttpProtocolGenerator>
    private val httpGenerator: HttpProtocolGenerator

    init {
        val symbolVisitorConfig = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig)
        val bootstrapProvider = RustCodegenPlugin.BaseSymbolProvider(context.model, symbolVisitorConfig)
        model = OperationNormalizer(bootstrapProvider).addOperationInputs(context.model)
        symbolProvider =
            RustCodegenPlugin.BaseSymbolProvider(model, SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        val service = settings.getService(context.model)
        val (protocol, generator) = ProtocolLoader.Default.protocolFor(context.model, service)
        protocolGenerator = generator

        val baseVisitor = SymbolVisitor(context.model, config = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        val normalizer = OperationNormalizer(baseVisitor)

        model = normalizer.addOperationInputs(context.model)
        symbolProvider = SymbolVisitor(model, config = SymbolVisitorConfig(runtimeConfig = settings.runtimeConfig))
        protocolConfig = ProtocolConfig(model, symbolProvider, settings.runtimeConfig, service, protocol)
        writers = CodegenWriterDelegator(
            context.fileManifest,
+1 −0
Original line number Diff line number Diff line
@@ -45,6 +45,7 @@ data class RuntimeType(val name: String, val dependency: RustDependency?, val na
        val From = RuntimeType("From", dependency = null, namespace = "std::convert")
        val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert")
        fun StdFmt(member: String) = RuntimeType("fmt::$member", dependency = null, namespace = "std")
        fun Std(member: String) = RuntimeType(member, dependency = null, namespace = "std")
        val StdError = RuntimeType("Error", dependency = null, namespace = "std::error")
        val HashSet = RuntimeType("HashSet", dependency = null, namespace = "std::collections")
        val HashMap = RuntimeType("HashMap", dependency = null, namespace = "std::collections")
+5 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.smithy

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.build.SmithyBuildPlugin
import software.amazon.smithy.model.Model

class RustCodegenPlugin : SmithyBuildPlugin {
    override fun getName(): String = "rust-codegen"
@@ -14,4 +15,8 @@ class RustCodegenPlugin : SmithyBuildPlugin {
    override fun execute(context: PluginContext) {
        CodegenVisitor(context).execute()
    }

    companion object {
        fun BaseSymbolProvider(model: Model, symbolVisitorConfig: SymbolVisitorConfig = DefaultConfig) = SymbolVisitor(model, config = symbolVisitorConfig).let { SymbolMetadataProvider(it) }
    }
}
Loading