Unverified Commit afb1f16c authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Move `RustSymbolProvider` and related types out of `SymbolVisitor` (#2380)

* Move base `RustSymbolProvider` types out of `SymbolVisitor`
* Rename `SymbolVisitorConfig` to `RustSymbolProviderConfig`
parent 3d007674
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.implBlock
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
@@ -67,7 +67,7 @@ class ClientCodegenVisitor(
    private val protocolGenerator: ClientProtocolGenerator

    init {
        val symbolVisitorConfig = SymbolVisitorConfig(
        val rustSymbolProviderConfig = RustSymbolProviderConfig(
            runtimeConfig = settings.runtimeConfig,
            renameExceptions = settings.codegenConfig.renameExceptions,
            nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
@@ -82,7 +82,7 @@ class ClientCodegenVisitor(
        model = codegenDecorator.transformModel(untransformedService, baseModel)
        // the model transformer _might_ change the service shape
        val service = settings.getService(model)
        symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig)
        symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, rustSymbolProviderConfig)

        codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings, codegenDecorator)

+4 −4
Original line number Diff line number Diff line
@@ -23,10 +23,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolP
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import java.util.logging.Level
import java.util.logging.Logger

@@ -74,10 +74,10 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
         * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered
         * with other symbol providers, documented inline, to handle the full scope of Smithy types.
         */
        fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig) =
            SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
        fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig) =
            SymbolVisitor(model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
                // Generate different types for EventStream shapes (e.g. transcribe streaming)
                .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
                .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
                // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
+3 −3
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings

@@ -49,7 +49,7 @@ fun clientTestRustSettings(
    customizationConfig,
)

val ClientTestSymbolVisitorConfig = SymbolVisitorConfig(
val ClientTestRustSymbolProviderConfig = RustSymbolProviderConfig(
    runtimeConfig = TestRuntimeConfig,
    renameExceptions = true,
    nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
@@ -60,7 +60,7 @@ fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSy
    RustClientCodegenPlugin.baseSymbolProvider(
        model,
        serviceShape ?: ServiceShape.builder().version("test").id("test#Service").build(),
        ClientTestSymbolVisitorConfig,
        ClientTestRustSymbolProviderConfig,
    )

fun testCodegenContext(
+3 −3
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.testutil.ClientTestSymbolVisitorConfig
import software.amazon.smithy.rust.codegen.client.testutil.ClientTestRustSymbolProviderConfig
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider
@@ -46,7 +46,7 @@ class EventStreamSymbolProviderTest {
        )

        val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
        val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestSymbolVisitorConfig), model, CodegenTarget.CLIENT)
        val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)

        // Look up the synthetic input/output rather than the original input/output
        val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
@@ -82,7 +82,7 @@ class EventStreamSymbolProviderTest {
        )

        val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
        val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestSymbolVisitorConfig), model, CodegenTarget.CLIENT)
        val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, ClientTestRustSymbolProviderConfig), model, CodegenTarget.CLIENT)

        // Look up the synthetic input/output rather than the original input/output
        val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape
+73 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule

/**
 * SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name.
 */
interface RustSymbolProvider : SymbolProvider, ModuleProvider {
    fun config(): RustSymbolProviderConfig
    fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed?

    override fun moduleForShape(shape: Shape): RustModule.LeafModule = config().moduleProvider.moduleForShape(shape)
    override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
        config().moduleProvider.moduleForOperationError(operation)
    override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
        config().moduleProvider.moduleForEventStreamError(eventStream)

    /** Returns the symbol for an operation error */
    fun symbolForOperationError(operation: OperationShape): Symbol

    /** Returns the symbol for an event stream error */
    fun symbolForEventStreamError(eventStream: UnionShape): Symbol
}

/**
 * Provider for RustModules so that the symbol provider knows where to organize things.
 */
interface ModuleProvider {
    /** Returns the module for a shape */
    fun moduleForShape(shape: Shape): RustModule.LeafModule

    /** Returns the module for an operation error */
    fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule

    /** Returns the module for an event stream error */
    fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule
}

/**
 * Configuration for symbol providers.
 */
data class RustSymbolProviderConfig(
    val runtimeConfig: RuntimeConfig,
    val renameExceptions: Boolean,
    val nullabilityCheckMode: NullableIndex.CheckMode,
    val moduleProvider: ModuleProvider,
)

/**
 * Default delegator to enable easily decorating another symbol provider.
 */
open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider {
    override fun config(): RustSymbolProviderConfig = base.config()
    override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? = base.toEnumVariantName(definition)
    override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape)
    override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape)
    override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation)
    override fun symbolForEventStreamError(eventStream: UnionShape): Symbol =
        base.symbolForEventStreamError(eventStream)
}
Loading