Unverified Commit 9d31b6c4 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Split generated code into multiple files (#2425)

* Split fluent client functions across multiple files

* Add (commented) test for large generated files to CI

* Split serialization/deserialization generated code across multiple files

* Remove extraneous newline from fluent client doc comments

* Add doc comments to `ProtocolFunctions`

* Simplify some doc comment generation

* Improve some function names
parent bec93c8a
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -20,13 +20,13 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientRestXmlFactory
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.traits.AllowInvalidXmlRoot
@@ -101,9 +101,9 @@ class S3ProtocolOverride(codegenContext: CodegenContext) : RestXml(codegenContex
    )

    override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType {
        return RuntimeType.forInlineFun("parse_http_error_metadata", RustModule.private("xml_deser")) {
        return ProtocolFunctions.crossOperationFn("parse_http_error_metadata") { fnName ->
            rustBlockTemplate(
                "pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorBuilder}, #{XmlDecodeError}>",
                "pub fn $fnName(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorBuilder}, #{XmlDecodeError}>",
                *errorScope,
            ) {
                rustTemplate(
+6 −5
Original line number Diff line number Diff line
@@ -13,15 +13,16 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.lensName
import software.amazon.smithy.rust.codegen.core.smithy.protocols.nestedAccessorName

/** Generator for accessing nested fields through optional values **/
class NestedAccessorGenerator(private val symbolProvider: RustSymbolProvider) {
class NestedAccessorGenerator(private val codegenContext: CodegenContext) {
    private val symbolProvider = codegenContext.symbolProvider
    private val module = RustModule.private("lens", "Generated accessors for nested fields")

    /**
@@ -30,7 +31,7 @@ class NestedAccessorGenerator(private val symbolProvider: RustSymbolProvider) {
    fun generateOwnedAccessor(root: StructureShape, path: List<MemberShape>): RuntimeType {
        check(path.isNotEmpty()) { "must not be called on an empty path" }
        val baseType = symbolProvider.toSymbol(path.last())
        val fnName = symbolProvider.lensName("", root, path)
        val fnName = symbolProvider.nestedAccessorName(codegenContext.serviceShape, "", root, path)
        return RuntimeType.forInlineFun(fnName, module) {
            rustTemplate(
                """
@@ -49,7 +50,7 @@ class NestedAccessorGenerator(private val symbolProvider: RustSymbolProvider) {
    fun generateBorrowingAccessor(root: StructureShape, path: List<MemberShape>): RuntimeType {
        check(path.isNotEmpty()) { "must not be called on an empty path" }
        val baseType = symbolProvider.toSymbol(path.last()).makeOptional()
        val fnName = symbolProvider.lensName("ref", root, path)
        val fnName = symbolProvider.nestedAccessorName(codegenContext.serviceShape, "ref", root, path)
        val referencedType = baseType.mapRustType { (it as RustType.Option).referenced(lifetime = null) }
        return RuntimeType.forInlineFun(fnName, module) {
            rustTemplate(
+3 −3
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ fun OperationShape.isPaginated(model: Model) =
        .findMemberWithTrait<IdempotencyTokenTrait>(model) == null

class PaginatorGenerator private constructor(
    codegenContext: ClientCodegenContext,
    private val codegenContext: ClientCodegenContext,
    operation: OperationShape,
    private val generics: FluentClientGenerics,
    retryClassifier: RuntimeType,
@@ -111,7 +111,7 @@ class PaginatorGenerator private constructor(

    /** Generate the paginator struct & impl **/
    private fun generate() = writable {
        val outputTokenLens = NestedAccessorGenerator(symbolProvider).generateBorrowingAccessor(
        val outputTokenLens = NestedAccessorGenerator(codegenContext).generateBorrowingAccessor(
            outputShape,
            paginationInfo.outputTokenMemberPath,
        )
@@ -266,7 +266,7 @@ class PaginatorGenerator private constructor(
                }

                """,
                "extract_items" to NestedAccessorGenerator(symbolProvider).generateOwnedAccessor(
                "extract_items" to NestedAccessorGenerator(codegenContext).generateOwnedAccessor(
                    outputShape,
                    paginationInfo.itemsMemberPath,
                ),
+129 −124
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.PaginatorGen
import software.amazon.smithy.rust.codegen.client.smithy.generators.isPaginated
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.EscapeFor
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
@@ -27,12 +28,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.asArgumentType
import software.amazon.smithy.rust.codegen.core.rustlang.asOptional
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.docLink
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.normalizeHtml
import software.amazon.smithy.rust.codegen.core.rustlang.qualifiedName
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters
@@ -65,6 +66,11 @@ class FluentClientGenerator(
    companion object {
        fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String =
            RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operationShape).name.toSnakeCase())
        fun clientOperationModuleName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String =
            RustReservedWords.escapeIfNeeded(
                symbolProvider.toSymbol(operationShape).name.toSnakeCase(),
                EscapeFor.ModuleName,
            )
    }

    private val serviceShape = codegenContext.serviceShape
@@ -76,9 +82,7 @@ class FluentClientGenerator(
    private val core = FluentClientCore(model)

    fun render(crate: RustCrate) {
        crate.withModule(ClientRustModule.client) {
            renderFluentClient(this)
        }
        renderFluentClient(crate)

        operations.forEach { operation ->
            crate.withModule(operation.fluentBuilderModule(codegenContext, symbolProvider)) {
@@ -89,9 +93,10 @@ class FluentClientGenerator(
        CustomizableOperationGenerator(codegenContext, generics).render(crate)
    }

    private fun renderFluentClient(writer: RustWriter) {
    private fun renderFluentClient(crate: RustCrate) {
        crate.withModule(ClientRustModule.client) {
            if (!codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme || reexportSmithyClientBuilder) {
            writer.rustTemplate(
                rustTemplate(
                    """
                    ##[doc(inline)]
                    pub use #{client}::Builder;
@@ -99,7 +104,7 @@ class FluentClientGenerator(
                    "client" to RuntimeType.smithyClient(runtimeConfig),
                )
            }
        writer.rustTemplate(
            rustTemplate(
                """
                ##[derive(Debug)]
                pub(crate) struct Handle#{generics_decl:W} {
@@ -156,13 +161,20 @@ class FluentClientGenerator(
                        }
                    },
            )
        writer.rustBlockTemplate(
            "impl${generics.inst} Client${generics.inst} #{bounds:W}",
        }

        operations.forEach { operation ->
            val name = symbolProvider.toSymbol(operation).name
            val fnName = clientOperationFnName(operation, symbolProvider)
            val moduleName = clientOperationModuleName(operation, symbolProvider)

            val privateModule = RustModule.private(moduleName, parent = ClientRustModule.client)
            crate.withModule(privateModule) {
                rustBlockTemplate(
                    "impl${generics.inst} super::Client${generics.inst} #{bounds:W}",
                    "client" to RuntimeType.smithyClient(runtimeConfig),
                    "bounds" to generics.bounds,
                ) {
            operations.forEach { operation ->
                val name = symbolProvider.toSymbol(operation).name
                    val fullPath = operation.fullyQualifiedFluentBuilder(codegenContext, symbolProvider)
                    val maybePaginated = if (operation.isPaginated(model)) {
                        "\n/// This operation supports pagination; See [`into_paginator()`]($fullPath::into_paginator)."
@@ -175,7 +187,7 @@ class FluentClientGenerator(
                    val operationErr = symbolProvider.symbolForOperationError(operation)

                    val inputFieldsBody = generateOperationShapeDocs(
                    writer,
                        this,
                        codegenContext,
                        symbolProvider,
                        operation,
@@ -183,48 +195,46 @@ class FluentClientGenerator(
                    ).joinToString("\n") { "///   - $it" }

                    val inputFieldsHead = if (inputFieldsBody.isNotEmpty()) {
                    "The fluent builder is configurable:"
                        "The fluent builder is configurable:\n"
                    } else {
                        "The fluent builder takes no input, just [`send`]($fullPath::send) it."
                    }

                    val outputFieldsBody =
                    generateShapeMemberDocs(writer, symbolProvider, output, model).joinToString("\n") {
                        generateShapeMemberDocs(this, symbolProvider, output, model).joinToString("\n") {
                            "///   - $it"
                        }

                    var outputFieldsHead = "On success, responds with [`${operationOk.name}`]($operationOk)"
                    if (outputFieldsBody.isNotEmpty()) {
                    outputFieldsHead += " with field(s):"
                        outputFieldsHead += " with field(s):\n"
                    }

                writer.rustTemplate(
                    rustTemplate(
                        """
                        /// Constructs a fluent builder for the [`$name`]($fullPath) operation.$maybePaginated
                        ///
                    /// - $inputFieldsHead
                    $inputFieldsBody
                    /// - $outputFieldsHead
                    $outputFieldsBody
                        /// - $inputFieldsHead$inputFieldsBody
                        /// - $outputFieldsHead$outputFieldsBody
                        /// - On failure, responds with [`SdkError<${operationErr.name}>`]($operationErr)
                        """,
                    )

                    // Write a deprecation notice if this operation is deprecated.
                writer.deprecatedShape(operation)
                    deprecatedShape(operation)

                writer.rustTemplate(
                    rustTemplate(
                        """
                    pub fn #{fnName}(&self) -> #{FluentBuilder}${generics.inst} {
                        pub fn $fnName(&self) -> #{FluentBuilder}${generics.inst} {
                            #{FluentBuilder}::new(self.handle.clone())
                        }
                        """,
                    "fnName" to writable { rust(clientOperationFnName(operation, symbolProvider)) },
                        "FluentBuilder" to operation.fluentBuilderType(codegenContext, symbolProvider),
                    )
                }
            }
        }
    }

    private fun RustWriter.renderFluentBuilder(operation: OperationShape) {
        val operationSymbol = symbolProvider.toSymbol(operation)
@@ -232,12 +242,7 @@ class FluentClientGenerator(
        val baseDerives = symbolProvider.toSymbol(input).expectRustMetadata().derives
        // Filter out any derive that isn't Clone. Then add a Debug derive
        val derives = baseDerives.filter { it == RuntimeType.Clone } + RuntimeType.Debug
        rust(
            """
            /// Fluent builder constructing a request to `${operationSymbol.name}`.
            ///
            """,
        )
        docs("Fluent builder constructing a request to `${operationSymbol.name}`.\n")

        val builderName = operation.fluentBuilderType(codegenContext, symbolProvider).name
        documentShape(operation, model, autoSuppressMissingDocs = false)
+5 −9
Original line number Diff line number Diff line
@@ -14,7 +14,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.http.Respons
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.assignment
@@ -36,6 +35,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDesc
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
@@ -45,7 +45,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isStreaming
import software.amazon.smithy.rust.codegen.core.util.outputShape
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

class HttpBoundProtocolGenerator(
    codegenContext: ClientCodegenContext,
@@ -71,7 +70,7 @@ class HttpBoundProtocolTraitImplGenerator(
    private val model = codegenContext.model
    private val runtimeConfig = codegenContext.runtimeConfig
    private val httpBindingResolver = protocol.httpBindingResolver
    private val operationDeserModule = RustModule.private("operation_deser")
    private val protocolFunctions = ProtocolFunctions(codegenContext)

    private val codegenScope = arrayOf(
        "ParseStrict" to RuntimeType.parseStrictResponse(runtimeConfig),
@@ -167,11 +166,10 @@ class HttpBoundProtocolTraitImplGenerator(
    }

    private fun parseError(operationShape: OperationShape, customizations: List<OperationCustomization>): RuntimeType {
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}_error"
        val outputShape = operationShape.outputShape(model)
        val outputSymbol = symbolProvider.toSymbol(outputShape)
        val errorSymbol = symbolProvider.symbolForOperationError(operationShape)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) {
        return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_error") { fnName ->
            Attribute.AllowClippyUnnecessaryWraps.render(this)
            rustBlockTemplate(
                "pub fn $fnName(response: &#{http}::Response<#{Bytes}>) -> std::result::Result<#{O}, #{E}>",
@@ -254,11 +252,10 @@ class HttpBoundProtocolTraitImplGenerator(
    }

    private fun parseStreamingResponse(operationShape: OperationShape, customizations: List<OperationCustomization>): RuntimeType {
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}"
        val outputShape = operationShape.outputShape(model)
        val outputSymbol = symbolProvider.toSymbol(outputShape)
        val errorSymbol = symbolProvider.symbolForOperationError(operationShape)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) {
        return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response") { fnName ->
            Attribute.AllowClippyUnnecessaryWraps.render(this)
            rustBlockTemplate(
                "pub fn $fnName(op_response: &mut #{operation}::Response) -> std::result::Result<#{O}, #{E}>",
@@ -283,11 +280,10 @@ class HttpBoundProtocolTraitImplGenerator(
    }

    private fun parseResponse(operationShape: OperationShape, customizations: List<OperationCustomization>): RuntimeType {
        val fnName = "parse_${operationShape.id.name.toSnakeCase()}_response"
        val outputShape = operationShape.outputShape(model)
        val outputSymbol = symbolProvider.toSymbol(outputShape)
        val errorSymbol = symbolProvider.symbolForOperationError(operationShape)
        return RuntimeType.forInlineFun(fnName, operationDeserModule) {
        return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response") { fnName ->
            Attribute.AllowClippyUnnecessaryWraps.render(this)
            rustBlockTemplate(
                "pub fn $fnName(response: &#{http}::Response<#{Bytes}>) -> std::result::Result<#{O}, #{E}>",
Loading