Unverified Commit 7cffe145 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Remove `toEnumVariantName` from `RustSymbolProvider` (#2377)

The `toEnumVariantName` function existed on symbol provider to work
around enum definitions not being shapes. In the future when we refactor
to use `EnumShape` instead of `EnumTrait`, there will be `MemberShape`s
for each enum member. This change incrementally moves us to that future
by creating fake `MemberShape`s in the enum generator from the enum
definition.
parent da264056
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
@@ -9,8 +9,10 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup
@@ -38,8 +40,18 @@ internal class StreamingShapeSymbolProviderTest {
        // "doing the right thing"
        val modelWithOperationTraits = OperationNormalizer.transform(model)
        val symbolProvider = testSymbolProvider(modelWithOperationTraits)
        symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechOutput\$data")).name shouldBe ("ByteStream")
        symbolProvider.toSymbol(modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechInput\$data")).name shouldBe ("ByteStream")
        modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechOutput\$data").also { shape ->
            symbolProvider.toSymbol(shape).also { symbol ->
                symbol.name shouldBe "data"
                symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream")
            }
        }
        modelWithOperationTraits.lookup<MemberShape>("test.synthetic#GenerateSpeechInput\$data").also { shape ->
            symbolProvider.toSymbol(shape).also { symbol ->
                symbol.name shouldBe "data"
                symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream")
            }
        }
    }

    @Test
+36 −37
Original line number Diff line number Diff line
@@ -8,28 +8,29 @@ package software.amazon.smithy.rust.codegen.core.rustlang
import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider
import software.amazon.smithy.codegen.core.ReservedWords
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) {
    private val internal =
        ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build()

    override fun toMemberName(shape: MemberShape): String {
        val baseName = internal.toMemberName(shape)
        return when (val container = model.expectShape(shape.container)) {
            is StructureShape -> when (baseName) {
        val baseName = super.toMemberName(shape)
        val reservedWordReplacedName = internal.toMemberName(shape)
        val container = model.expectShape(shape.container)
        return when {
            container is StructureShape -> when (baseName) {
                "build" -> "build_value"
                "builder" -> "builder_value"
                "default" -> "default_value"
@@ -40,10 +41,10 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : Wra
                "customize" -> "customize_value"
                // To avoid conflicts with the error metadata `meta` field
                "meta" -> "meta_value"
                else -> baseName
                else -> reservedWordReplacedName
            }

            is UnionShape -> when (baseName) {
            container is UnionShape -> when (baseName) {
                // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server
                // that represent union variants that have been added since this SDK was generated.
                UnionGenerator.UnknownVariantName -> "${UnionGenerator.UnknownVariantName}Value"
@@ -53,7 +54,20 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : Wra
                "Self" -> "SelfValue"
                // Real models won't end in `_` so it's safe to stop here
                "SelfValue" -> "SelfValue_"
                else -> baseName
                else -> reservedWordReplacedName
            }

            container is EnumShape || container.hasTrait<EnumTrait>() -> when (baseName) {
                // Self cannot be used as a raw identifier, so we can't use the normal escaping strategy
                // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4
                "Self" -> "SelfValue"
                // Real models won't end in `_` so it's safe to stop here
                "SelfValue" -> "SelfValue_"
                // Unknown is used as the name of the variant containing unexpected values
                "Unknown" -> "UnknownValue"
                // Real models won't end in `_` so it's safe to stop here
                "UnknownValue" -> "UnknownValue_"
                else -> reservedWordReplacedName
            }

            else -> error("unexpected container: $container")
@@ -67,46 +81,31 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : Wra
     * code generators to generate special docs.
     */
    override fun toSymbol(shape: Shape): Symbol {
        // Sanity check that the symbol provider stack is set up correctly
        check(super.toSymbol(shape).renamedFrom() == null) {
            "RustReservedWordSymbolProvider should only run once"
        }

        var renamedSymbol = internal.toSymbol(shape)
        return when (shape) {
            is MemberShape -> {
                val container = model.expectShape(shape.container)
                if (!(container is StructureShape || container is UnionShape)) {
                val containerIsEnum = container is EnumShape || container.hasTrait<EnumTrait>()
                if (container !is StructureShape && container !is UnionShape && !containerIsEnum) {
                    return base.toSymbol(shape)
                }
                val previousName = base.toMemberName(shape)
                val escapedName = this.toMemberName(shape)
                val baseSymbol = base.toSymbol(shape)
                // if the names don't match and it isn't a simple escaping with `r#`, record a rename
                baseSymbol.letIf(escapedName != previousName && !escapedName.contains("r#")) {
                    it.toBuilder().renamedFrom(previousName).build()
                }
                renamedSymbol.toBuilder().name(escapedName)
                    .letIf(escapedName != previousName && !escapedName.contains("r#")) {
                        it.renamedFrom(previousName)
                    }.build()
            }

            else -> base.toSymbol(shape)
        }
    }

    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
        val baseName = base.toEnumVariantName(definition) ?: return null
        check(definition.name.orNull()?.toPascalCase() == baseName.name) {
            "Enum variants must already be in pascal case ${baseName.name} differed from ${baseName.name.toPascalCase()}. Definition: ${definition.name}"
        }
        check(baseName.renamedFrom == null) {
            "definitions should only pass through the renamer once"
        }
        return when (baseName.name) {
            // Self cannot be used as a raw identifier, so we can't use the normal escaping strategy
            // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4
            "Self" -> MaybeRenamed("SelfValue", "Self")
            // Real models won't end in `_` so it's safe to stop here
            "SelfValue" -> MaybeRenamed("SelfValue_", "SelfValue")
            // Unknown is used as the name of the variant containing unexpected values
            "Unknown" -> MaybeRenamed("UnknownValue", "Unknown")
            // Real models won't end in `_` so it's safe to stop here
            "UnknownValue" -> MaybeRenamed("UnknownValue_", "UnknownValue")
            else -> baseName
        }
    }
}

object RustReservedWords : ReservedWords {
+1 −5
Original line number Diff line number Diff line
@@ -14,19 +14,16 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule

/**
 * SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name.
 * SymbolProvider interface that carries additional configuration and module/symbol resolution.
 */
interface RustSymbolProvider : SymbolProvider {
    val model: Model
    val moduleProviderContext: ModuleProviderContext
    val config: RustSymbolProviderConfig

    fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

    fun moduleForShape(shape: Shape): RustModule.LeafModule =
        config.moduleProvider.moduleForShape(moduleProviderContext, shape)
    fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
@@ -84,7 +81,6 @@ open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSy
    override val moduleProviderContext: ModuleProviderContext get() = base.moduleProviderContext
    override val config: RustSymbolProviderConfig get() = base.config

    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
    override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
    override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
    override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation)
+5 −0
Original line number Diff line number Diff line
@@ -114,6 +114,11 @@ class BaseSymbolMetadataProvider(
            }

            is UnionShape, is CollectionShape, is MapShape -> RustMetadata(visibility = Visibility.PUBLIC)

            // This covers strings with the enum trait for now, and can be removed once we're fully on EnumShape
            // TODO(https://github.com/awslabs/smithy-rs/issues/1700): Remove this `is StringShape` match arm
            is StringShape -> RustMetadata(visibility = Visibility.PUBLIC)

            else -> TODO("Unrecognized container type: $container")
        }

+8 −17
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
@@ -35,7 +36,6 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
@@ -47,7 +47,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import kotlin.reflect.KClass
@@ -132,22 +131,14 @@ open class SymbolVisitor(
            module.toType().resolve("${symbol.name}Error").toSymbol().toBuilder().locatedIn(module).build()
        }

    /**
     * Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context
     * where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type.
     *
     * Because enum variants are not member shape, a separate handler is required.
     */
    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? {
        val baseName = definition.name.orNull()?.toPascalCase() ?: return null
        return MaybeRenamed(baseName, null)
    }

    override fun toMemberName(shape: MemberShape): String = when (val container = model.expectShape(shape.container)) {
        is StructureShape -> shape.memberName.toSnakeCase()
        is UnionShape -> shape.memberName.toPascalCase()
    override fun toMemberName(shape: MemberShape): String {
        val container = model.expectShape(shape.container)
        return when {
            container is StructureShape -> shape.memberName.toSnakeCase()
            container is UnionShape || container is EnumShape || container.hasTrait<EnumTrait>() -> shape.memberName.toPascalCase()
            else -> error("unexpected container shape: $container")
        }
    }

    override fun blobShape(shape: BlobShape?): Symbol {
        return RuntimeType.blob(config.runtimeConfig).toSymbol()
Loading