Unverified Commit 86bddca6 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Move inputs, outputs, and op errors into operation modules (#2394)

parent afb1f16c
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -97,7 +97,8 @@ fun generateSmithyBuild(services: AwsServices): String {
                        "codegen": {
                            "includeFluentClient": false,
                            "renameErrors": false,
                            "eventStreamAllowList": [$eventStreamAllowListMembers]
                            "eventStreamAllowList": [$eventStreamAllowListMembers],
                            "enableNewCrateOrganizationScheme": false
                        },
                        "service": "${service.service}",
                        "module": "$moduleName",
+5 −2
Original line number Diff line number Diff line
@@ -71,7 +71,10 @@ class ClientCodegenVisitor(
            runtimeConfig = settings.runtimeConfig,
            renameExceptions = settings.codegenConfig.renameExceptions,
            nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1,
            moduleProvider = ClientModuleProvider,
            moduleProvider = when (settings.codegenConfig.enableNewCrateOrganizationScheme) {
                true -> ClientModuleProvider
                else -> OldModuleSchemeClientModuleProvider
            },
        )
        val baseModel = baselineTransform(context.model)
        val untransformedService = settings.getService(baseModel)
@@ -263,7 +266,7 @@ class ClientCodegenVisitor(
     * Generate errors for operation shapes
     */
    override fun operationShape(shape: OperationShape) {
        rustCrate.withModule(ClientRustModule.Error) {
        rustCrate.withModule(symbolProvider.moduleForOperationError(shape)) {
            OperationErrorGenerator(
                model,
                symbolProvider,
+66 −5
Original line number Diff line number Diff line
@@ -5,16 +5,23 @@

package software.amazon.smithy.rust.codegen.client.smithy

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

/**
 * Modules for code generated client crates.
@@ -22,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait
object ClientRustModule {
    /** crate::client */
    val client = Client.self

    object Client {
        /** crate::client */
        val self = RustModule.public("client", "Client and fluent builders for calling the service.")
@@ -40,7 +48,55 @@ object ClientRustModule {
}

object ClientModuleProvider : ModuleProvider {
    override fun moduleForShape(shape: Shape): RustModule.LeafModule = when (shape) {
    override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) {
        is OperationShape -> perOperationModule(context, shape)
        is StructureShape -> when {
            shape.hasTrait<ErrorTrait>() -> ClientRustModule.Error
            shape.hasTrait<SyntheticInputTrait>() -> perOperationModule(context, shape)
            shape.hasTrait<SyntheticOutputTrait>() -> perOperationModule(context, shape)
            else -> ClientRustModule.Model
        }

        else -> ClientRustModule.Model
    }

    override fun moduleForOperationError(
        context: ModuleProviderContext,
        operation: OperationShape,
    ): RustModule.LeafModule = perOperationModule(context, operation)

    override fun moduleForEventStreamError(
        context: ModuleProviderContext,
        eventStream: UnionShape,
    ): RustModule.LeafModule = ClientRustModule.Error

    private fun Shape.findOperation(model: Model): OperationShape {
        val inputTrait = getTrait<SyntheticInputTrait>()
        val outputTrait = getTrait<SyntheticOutputTrait>()
        return when {
            this is OperationShape -> this
            inputTrait != null -> model.expectShape(inputTrait.operation, OperationShape::class.java)
            outputTrait != null -> model.expectShape(outputTrait.operation, OperationShape::class.java)
            else -> UNREACHABLE("this is only called with compatible shapes")
        }
    }

    private fun perOperationModule(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule {
        val operationShape = shape.findOperation(context.model)
        val contextName = operationShape.contextName(context.serviceShape)
        val operationModuleName =
            RustReservedWords.escapeIfNeeded(contextName.toSnakeCase())
        return RustModule.public(
            operationModuleName,
            parent = ClientRustModule.Operation,
            documentation = "Types for the `$contextName` operation.",
        )
    }
}

// TODO(CrateReorganization): Remove this provider
object OldModuleSchemeClientModuleProvider : ModuleProvider {
    override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) {
        is OperationShape -> ClientRustModule.Operation
        is StructureShape -> when {
            shape.hasTrait<ErrorTrait>() -> ClientRustModule.Error
@@ -48,12 +104,17 @@ object ClientModuleProvider : ModuleProvider {
            shape.hasTrait<SyntheticOutputTrait>() -> ClientRustModule.Output
            else -> ClientRustModule.Model
        }

        else -> ClientRustModule.Model
    }

    override fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule =
        ClientRustModule.Error
    override fun moduleForOperationError(
        context: ModuleProviderContext,
        operation: OperationShape,
    ): RustModule.LeafModule = ClientRustModule.Error

    override fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule =
        ClientRustModule.Error
    override fun moduleForEventStreamError(
        context: ModuleProviderContext,
        eventStream: UnionShape,
    ): RustModule.LeafModule = ClientRustModule.Error
}
+5 −5
Original line number Diff line number Diff line
@@ -77,15 +77,15 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
        fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig) =
            SymbolVisitor(model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
                // Generate different types for EventStream shapes (e.g. transcribe streaming)
                .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) }
                .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) }
                // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
                .let { StreamingShapeSymbolProvider(it, model) }
                .let { StreamingShapeSymbolProvider(it) }
                // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes
                .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(NonExhaustive)) }
                .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) }
                // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`)
                .let { StreamingShapeMetadataProvider(it, model) }
                .let { StreamingShapeMetadataProvider(it) }
                // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot
                // be the name of an operation input
                .let { RustReservedWordSymbolProvider(it, model) }
                .let { RustReservedWordSymbolProvider(it) }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ class PaginatorGenerator private constructor(
    }

    private val paginatorName = "${operation.id.name.toPascalCase()}Paginator"
    private val runtimeConfig = symbolProvider.config().runtimeConfig
    private val runtimeConfig = symbolProvider.config.runtimeConfig
    private val idx = PaginatedIndex.of(model)
    private val paginationInfo =
        idx.getPaginationInfo(service, operation).orNull() ?: PANIC("failed to load pagination info")
Loading