Unverified Commit a47ada58 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Name collision refactor (#492)

* Centralize colliding renames

* Add integration test

* Cleanup exception

* Add tests & fixes for more invalid Rust identifiers
parent 14f435e8
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -39,7 +39,11 @@ operation ReservedWordsAsMembers {

structure ReservedWords {
    as: Integer,
    async: Boolean
    async: Boolean,
    enum: UnknownVariantCollidingEnum,
    self: Boolean,
    crate: Boolean,
    super: Boolean
}

@httpRequestTests([
@@ -97,5 +101,8 @@ structure CollidingException {
@enum([
    { name: "Known", value: "Known" },
    { name: "Unknown", value: "Unknown" },
    { name: "Self", value: "Self" },
    { name: "UnknownValue", value: "UnknownValue" },
    { name: "SelfValue", value: "SelfValue" }
])
string UnknownVariantCollidingEnum
+35 −3
Original line number Diff line number Diff line
@@ -10,11 +10,16 @@ import software.amazon.smithy.codegen.core.ReservedWords
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.util.toPascalCase

class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) {
    private val internal =
        ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build()

class RustReservedWordSymbolProvider(base: RustSymbolProvider) : WrappingSymbolProvider(base) {
    private val internal = ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build()
    override fun toMemberName(shape: MemberShape): String {
        return internal.toMemberName(shape)
    }
@@ -22,6 +27,28 @@ class RustReservedWordSymbolProvider(base: RustSymbolProvider) : WrappingSymbolP
    override fun toSymbol(shape: Shape): Symbol {
        return internal.toSymbol(shape)
    }

    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
        val baseName = base.toEnumVariantName(definition) ?: return null
        check(baseName.name.toPascalCase() == baseName.name) {
            "Enum variants must already be in pascal case"
        }
        check(baseName.renamedFrom == null) {
            "definitions should only pass through the renamer once"
        }
        return when (baseName.name) {
            // Self cannot be used as a raw identifier, so we can't use the normal escaping strategy
            // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4
            "Self" -> MaybeRenamed("SelfValue", "Self")
            // Real models won't end in `_` so it's safe to stop here
            "SelfValue" -> MaybeRenamed("SelfValue_", "SelfValue")
            // Unknown is used as the name of the variant containing unexpected values
            "Unknown" -> MaybeRenamed("UnknownValue", "Unknown")
            // Real models won't end in `_` so it's safe to stop here
            "UnknownValue" -> MaybeRenamed("UnknownValue_", "UnknownValue")
            else -> baseName
        }
    }
}

object RustReservedWords : ReservedWords {
@@ -81,7 +108,12 @@ object RustReservedWords : ReservedWords {
        "try"
    )

    override fun escape(word: String): String = "r##$word"
    private val cantBeRaw = setOf("self", "crate", "super")

    override fun escape(word: String): String = when {
        cantBeRaw.contains(word) -> "${word}_"
        else -> "r##$word"
    }

    override fun isReserved(word: String): Boolean = RustKeywords.contains(word)
}
+5 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.Attribute.Companion.NonExhaustive
@@ -27,6 +28,10 @@ open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSy
        return base.config()
    }

    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
        return base.toEnumVariantName(definition)
    }

    override fun toSymbol(shape: Shape): Symbol {
        return base.toSymbol(shape)
    }
+16 −1
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
@@ -44,6 +45,7 @@ import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.toPascalCase
import software.amazon.smithy.rust.codegen.util.toSnakeCase
import software.amazon.smithy.utils.StringUtils
@@ -71,7 +73,12 @@ data class SymbolVisitorConfig(

// TODO: consider if this is better handled as a wrapper
val DefaultConfig =
    SymbolVisitorConfig(runtimeConfig = RuntimeConfig(), handleOptionality = true, handleRustBoxing = true, codegenConfig = CodegenConfig())
    SymbolVisitorConfig(
        runtimeConfig = RuntimeConfig(),
        handleOptionality = true,
        handleRustBoxing = true,
        codegenConfig = CodegenConfig()
    )

data class SymbolLocation(val namespace: String) {
    val filename = "$namespace.rs"
@@ -117,8 +124,11 @@ fun Symbol.Builder.locatedIn(symbolLocation: SymbolLocation): Symbol.Builder {
        .rustType(newRustType)
}

data class MaybeRenamed(val name: String, val renamedFrom: String?)

interface RustSymbolProvider : SymbolProvider {
    fun config(): SymbolVisitorConfig
    fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?
}

class SymbolVisitor(
@@ -142,6 +152,11 @@ class SymbolVisitor(
        }
    }

    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
        val baseName = definition.name.orNull()?.toPascalCase() ?: return null
        return MaybeRenamed(baseName, null)
    }

    override fun toMemberName(shape: MemberShape): String = shape.memberName.toSnakeCase()

    override fun blobShape(shape: BlobShape?): Symbol {
+23 −34
Original line number Diff line number Diff line
@@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.smithy.generators

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.DocumentationTrait
@@ -18,45 +17,38 @@ import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.util.doubleQuote
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.toPascalCase

/** Model that wraps [EnumDefinition] to calculate and cache values required to generate the Rust enum source. */
internal class EnumMemberModel(private val definition: EnumDefinition) {
internal class EnumMemberModel(private val definition: EnumDefinition, private val symbolProvider: RustSymbolProvider) {
    // Because enum variants always start with an upper case letter, they will never
    // conflict with reserved words (which are always lower case), therefore, we never need
    // to fall back to raw identifiers
    private val unescapedName: String? = definition.name.orNull()?.toPascalCase()

    val collidesWithUnknown: Boolean = unescapedName == EnumGenerator.UnknownVariant

    /** Enum name with correct case format and collision resolution */
    fun derivedName(): String = when (collidesWithUnknown) {
        // If there is a variant named "Unknown", then rename it to "UnknownValue" so that it
        // doesn't conflict with the code generator's "Unknown" variant that exists for backwards compatibility.
        true -> "UnknownValue"
        else -> checkNotNull(unescapedName) { "Enum variants must be named to derive a name. This is a bug." }
    }

    val value: String get() = definition.value

    fun name(): MaybeRenamed? = symbolProvider.toEnumVariantName(definition)

    private fun renderDocumentation(writer: RustWriter) {
        val name =
            checkNotNull(name()) { "cannot generate docs for unnamed enum variants" }
        writer.docWithNote(
            definition.documentation.orNull(),
            when (collidesWithUnknown) {
                true ->
                    "`::${EnumGenerator.UnknownVariant}` has been renamed to `::${EnumGenerator.EscapedUnknownVariant}`. " +
                        "`::${EnumGenerator.UnknownVariant}` refers to additional values that may have been added since " +
                        "this enum was generated."
                else -> null
            name.renamedFrom?.let { renamedFrom ->
                "`::$renamedFrom` has been renamed to `::${name.name}`."
            }

        )
    }

    fun derivedName() = checkNotNull(symbolProvider.toEnumVariantName(definition)).name

    fun render(writer: RustWriter) {
        renderDocumentation(writer)
        writer.write("${derivedName()},")
@@ -74,7 +66,7 @@ private fun RustWriter.docWithNote(doc: String?, note: String?) {

class EnumGenerator(
    private val model: Model,
    symbolProvider: SymbolProvider,
    private val symbolProvider: RustSymbolProvider,
    private val writer: RustWriter,
    private val shape: StringShape,
    private val enumTrait: EnumTrait
@@ -82,16 +74,13 @@ class EnumGenerator(
    private val symbol = symbolProvider.toSymbol(shape)
    private val enumName = symbol.name
    private val meta = symbol.expectRustMetadata()
    private val sortedMembers: List<EnumMemberModel> = enumTrait.values.sortedBy { it.value }.map(::EnumMemberModel)
    private val sortedMembers: List<EnumMemberModel> =
        enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(it, symbolProvider) }

    companion object {
        /**
         * For enums with named members, variants with names that collide with the generated unknown enum
         * member get renamed to this [EscapedUnknownVariant] value.
         */
        const val EscapedUnknownVariant = "UnknownValue"
        /** Name of the generated unknown enum member name for enums with named members. */
        const val UnknownVariant = "Unknown"

        /** Name of the function on the enum impl to get a vec of value names */
        const val Values = "values"
    }
@@ -142,20 +131,20 @@ class EnumGenerator(
    }

    private fun renderEnum() {
        val renamedWarning =
            sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") {
                val previousName = it.renamedFrom!!
                "`$enumName::$previousName` has been renamed to `::${it.name}`."
            }
        writer.docWithNote(
            shape.getTrait<DocumentationTrait>()?.value,
            when (sortedMembers.any { it.collidesWithUnknown }) {
                true ->
                    "`$enumName::$UnknownVariant` has been renamed to `::$EscapedUnknownVariant`. " +
                        "`$enumName::$UnknownVariant` refers to additional values that may have been added since " +
                        "this enum was generated."
                else -> null
            }
            renamedWarning.ifBlank { null }
        )

        meta.render(writer)
        writer.rustBlock("enum $enumName") {
            sortedMembers.forEach { member -> member.render(writer) }
            docs("$UnknownVariant contains new variants that have been added since this code was generated.")
            write("$UnknownVariant(String)")
        }
    }
Loading