Unverified Commit 35989d2f authored by 82marbag's avatar 82marbag Committed by GitHub
Browse files

Unknown enum variants removed from server (#1398)



The server must have the most up to date variants and the unknown enum
variant should not be used. Clients are generated with it because they
might not have the most recent model and the server might return
an unknown variant to them.

Closes #1187

Signed-off-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>

Co-authored-by: default avatarDaniele Ahmed <ahmeddan@amazon.de>
Co-authored-by: default avatardavid-perez <d@vidp.dev>
Co-authored-by: default avatarMatteo Bigoi <1781140+crisidev@users.noreply.github.com>
parent d6e2944c
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin {
    override fun execute(context: PluginContext) {
        // Suppress extremely noisy logs about reserved words
        Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF
        // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of
        // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] return different types of
        // customization. A customization is a function of:
        // - location (e.g. the mutate section of an operation)
        // - context (e.g. the of the operation)
+2 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
@@ -28,7 +29,6 @@ import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.implBlock
@@ -184,7 +184,7 @@ class ServerCodegenVisitor(context: PluginContext, private val codegenDecorator:
        logger.info("[rust-server-codegen] Generating an enum $shape")
        shape.getTrait<EnumTrait>()?.also { enum ->
            rustCrate.useShapeWriter(shape) { writer ->
                EnumGenerator(model, symbolProvider, writer, shape, enum).render()
                ServerEnumGenerator(model, symbolProvider, writer, shape, enum, codegenContext.runtimeConfig).render()
            }
        }
    }
+1 −1
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ class ServerCombinedErrorGenerator(
            }
        }

        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.stdfmt.member("Display")) {
        writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.Display) {
            rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
                delegateToVariants {
                    rust("_inner.fmt(f)")
+94 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.smithy.CodegenMode
import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator
import software.amazon.smithy.rust.codegen.util.dq

class ServerEnumGenerator(
    model: Model,
    symbolProvider: RustSymbolProvider,
    private val writer: RustWriter,
    shape: StringShape,
    enumTrait: EnumTrait,
    private val runtimeConfig: RuntimeConfig,
) : EnumGenerator(model, symbolProvider, writer, shape, enumTrait) {
    override var mode: CodegenMode = CodegenMode.Server
    private val errorStruct = "${enumName}UnknownVariantError"

    override fun renderFromForStr() {
        writer.rust(
            """
            ##[derive(Debug, PartialEq, Eq, Hash)]
            pub struct $errorStruct(String);
            """
        )
        writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) {
            write("type Error = $errorStruct;")
            writer.rustBlock("fn try_from(s: &str) -> Result<Self, <$enumName as #T<&str>>::Error>", RuntimeType.TryFrom) {
                writer.rustBlock("match s") {
                    sortedMembers.forEach { member ->
                        write("${member.value.dq()} => Ok($enumName::${member.derivedName()}),")
                    }
                    write("_ => Err($errorStruct(s.to_owned()))")
                }
            }
        }
        writer.rustTemplate(
            """
            impl #{From}<$errorStruct> for #{RequestRejection} {
                fn from(e: $errorStruct) -> Self {
                    Self::EnumVariantNotFound(Box::new(e))
                }
            }

            impl #{From}<$errorStruct> for #{JsonDeserialize} {
                fn from(e: $errorStruct) -> Self {
                    Self::custom(format!("unknown variant {}", e))
                }
            }

            impl #{StdError} for $errorStruct { }

            impl #{Display} for $errorStruct {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    self.0.fmt(f)
                }
            }
            """,
            "Display" to RuntimeType.Display,
            "From" to RuntimeType.From,
            "StdError" to RuntimeType.StdError,
            "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig),
            "JsonDeserialize" to RuntimeType.jsonDeserialize(runtimeConfig),
        )
    }

    override fun renderFromStr() {
        writer.rust(
            """
            impl std::str::FromStr for $enumName {
                type Err = $errorStruct;

                fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
                    $enumName::try_from(s)
                }
            }
            """
        )
    }
}
+3 −2
Original line number Diff line number Diff line
@@ -119,7 +119,7 @@ class ServerOperationRegistryGenerator(
            pub enum ${operationRegistryBuilderName}Error {
                UninitializedField(&'static str)
            }
            impl std::fmt::Display for ${operationRegistryBuilderName}Error {
            impl #{Display} for ${operationRegistryBuilderName}Error {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match self {
                        Self::UninitializedField(v) => write!(f, "{}", v),
@@ -128,7 +128,8 @@ class ServerOperationRegistryGenerator(
            }
            impl #{StdError} for ${operationRegistryBuilderName}Error {}
            """.trimIndent(),
            *codegenScope
            *codegenScope,
            "Display" to RuntimeType.Display,
        )
    }

Loading