Unverified Commit 39a51b72 authored by Ignatius's avatar Ignatius Committed by GitHub
Browse files

Add EnumSection to allow decorators to modify enum member attributes (#4039)



## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
Allows users to use decorators to add additional attributes to members
of an enum.

<!--- If it fixes an open issue, please link to the issue here -->
N/A

## Description
<!--- Describe your changes in detail -->
Adds `EnumSection` with named `AdditionalMemberAttributes` to allow
decorators to modify enum codegen.

## Testing
<!--- Please describe in detail how you tested your changes -->
Could not find existing unit tests for customizations; please advise if
you would like them to be added here. Ran `./gradlew` to ensure current
tests pass.
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [X] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.


----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._

---------

Co-authored-by: default avatarLandon James <lnj@amazon.com>
Co-authored-by: default avatarysaito1001 <awsaito@amazon.com>
parent de4be56e
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
---
applies_to: ["client", "server"]
authors: [Dorenavant]
references: []
breaking: false
new_feature: true
bug_fix: false
---

Add EnumSection to allow decorators to modify enum member attributes
+22 −4
Original line number Diff line number Diff line
@@ -93,7 +93,14 @@ class ClientCodegenVisitor(
        model = codegenDecorator.transformModel(untransformedService, baseModel, settings)
        // the model transformer _might_ change the service shape
        val service = settings.getService(model)
        symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(settings, model, service, rustSymbolProviderConfig, codegenDecorator)
        symbolProvider =
            RustClientCodegenPlugin.baseSymbolProvider(
                settings,
                model,
                service,
                rustSymbolProviderConfig,
                codegenDecorator,
            )

        codegenContext =
            ClientCodegenContext(
@@ -177,7 +184,10 @@ class ClientCodegenVisitor(
        )
        try {
            // use an increased max_width to make rustfmt fail less frequently
            "cargo fmt -- --config max_width=150".runCommand(fileManifest.baseDir, timeout = settings.codegenConfig.formatTimeoutSeconds.toLong())
            "cargo fmt -- --config max_width=150".runCommand(
                fileManifest.baseDir,
                timeout = settings.codegenConfig.formatTimeoutSeconds.toLong(),
            )
        } catch (err: CommandError) {
            logger.warning("Failed to run cargo fmt: [${service.id}]\n${err.output}")
        }
@@ -236,7 +246,10 @@ class ClientCodegenVisitor(

                        implBlock(symbolProvider.toSymbol(shape)) {
                            BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape)
                            if (codegenContext.protocolImpl?.httpBindingResolver?.handlesEventStreamInitialResponse(shape) == true) {
                            if (codegenContext.protocolImpl?.httpBindingResolver?.handlesEventStreamInitialResponse(
                                    shape,
                                ) == true
                            ) {
                                BuilderGenerator.renderIntoBuilderMethod(this, symbolProvider, shape)
                            }
                        }
@@ -251,6 +264,7 @@ class ClientCodegenVisitor(
                    }
                    struct to builder
                }

                else -> {
                    val errorGenerator =
                        ErrorGenerator(
@@ -283,7 +297,11 @@ class ClientCodegenVisitor(
        if (shape.hasTrait<EnumTrait>()) {
            val privateModule = privateModule(shape)
            rustCrate.inPrivateModuleWithReexport(privateModule, symbolProvider.toSymbol(shape)) {
                ClientEnumGenerator(codegenContext, shape).render(this)
                ClientEnumGenerator(
                    codegenContext,
                    shape,
                    codegenDecorator.enumCustomizations(codegenContext, emptyList()),
                ).render(this)
            }
        }
    }
+17 −11
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel
@@ -283,7 +284,11 @@ data class InfallibleEnumType(
    }
}

class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringShape) :
class ClientEnumGenerator(
    codegenContext: ClientCodegenContext,
    shape: StringShape,
    customizations: List<EnumCustomization>,
) :
    EnumGenerator(
            codegenContext.model,
            codegenContext.symbolProvider,
@@ -295,6 +300,7 @@ class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringSha
                    parent = ClientRustModule.primitives,
                ),
            ),
            customizations,
        )

private fun unknownVariantError(): RuntimeType =
+6 −6
Original line number Diff line number Diff line
@@ -28,7 +28,7 @@ class ClientEnumGeneratorTest {
            val context = testClientCodegenContext(model)
            val project = TestWorkspace.testProject(context.symbolProvider)
            project.moduleFor(shape) {
                ClientEnumGenerator(context, shape).render(this)
                ClientEnumGenerator(context, shape, emptyList()).render(this)
                unitTest(
                    "matching_on_enum_should_be_forward_compatible",
                    """
@@ -88,7 +88,7 @@ class ClientEnumGeneratorTest {
        val context = testClientCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.moduleFor(shape) {
            ClientEnumGenerator(context, shape).render(this)
            ClientEnumGenerator(context, shape, emptyList()).render(this)
            unitTest(
                "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait",
                """
@@ -134,10 +134,10 @@ class ClientEnumGeneratorTest {
        val context = testClientCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.moduleFor(shapeA) {
            ClientEnumGenerator(context, shapeA).render(this)
            ClientEnumGenerator(context, shapeA, emptyList()).render(this)
        }
        project.moduleFor(shapeB) {
            ClientEnumGenerator(context, shapeB).render(this)
            ClientEnumGenerator(context, shapeB, emptyList()).render(this)
            unitTest(
                "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait",
                """
@@ -172,7 +172,7 @@ class ClientEnumGeneratorTest {
        val context = testClientCodegenContext(model)
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.moduleFor(shape) {
            ClientEnumGenerator(context, shape).render(this)
            ClientEnumGenerator(context, shape, emptyList()).render(this)
            unitTest(
                "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model",
                """
@@ -205,7 +205,7 @@ class ClientEnumGeneratorTest {
        val project = TestWorkspace.testProject(context.symbolProvider)
        project.moduleFor(shape) {
            rust("##![allow(deprecated)]")
            ClientEnumGenerator(context, shape).render(this)
            ClientEnumGenerator(context, shape, emptyList()).render(this)
            unitTest(
                "generated_named_enums_roundtrip",
                """
+2 −2
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ internal class ClientInstantiatorTest {

        val project = TestWorkspace.testProject(symbolProvider)
        project.moduleFor(shape) {
            ClientEnumGenerator(codegenContext, shape).render(this)
            ClientEnumGenerator(codegenContext, shape, emptyList()).render(this)
            unitTest("generate_named_enums") {
                withBlock("let result = ", ";") {
                    sut.render(this, shape, data)
@@ -74,7 +74,7 @@ internal class ClientInstantiatorTest {

        val project = TestWorkspace.testProject(symbolProvider)
        project.moduleFor(shape) {
            ClientEnumGenerator(codegenContext, shape).render(this)
            ClientEnumGenerator(codegenContext, shape, emptyList()).render(this)
            unitTest("generate_unnamed_enums") {
                withBlock("let result = ", ";") {
                    sut.render(this, shape, data)
Loading