Unverified Commit 374a1c5d authored by david-perez's avatar david-perez Committed by GitHub
Browse files

Bubble up unknown enum variant error when parsing XML in the server (#1485)

XML deserialization of enums is currently broken.

This commit should have been done as part of #1398, but we only tested
that patch using the restJson1 protocol.

The converter from the unknown enum variant error into
`aws_smithy_json::deserialize::Error` has been removed from
`ServerEnumGenerator`, since it's protocol-specific logic. We instead
use `map_err` in the protocol-specific parsers before bubbling up using
`?`.

Fixes #1477.
parent 610d963d
Loading
Loading
Loading
Loading
+0 −6
Original line number Diff line number Diff line
@@ -55,11 +55,6 @@ open class ServerEnumGenerator(
                    Self::EnumVariantNotFound(Box::new(e))
                }
            }
            impl #{From}<$errorStruct> for #{JsonDeserialize} {
                fn from(e: $errorStruct) -> Self {
                    Self::custom(format!("unknown variant {}", e))
                }
            }
            impl #{StdError} for $errorStruct { }
            impl #{Display} for $errorStruct {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -71,7 +66,6 @@ open class ServerEnumGenerator(
            "From" to RuntimeType.From,
            "StdError" to RuntimeType.StdError,
            "RequestRejection" to ServerRuntimeType.RequestRejection(runtimeConfig),
            "JsonDeserialize" to RuntimeType.jsonDeserialize(runtimeConfig),
        )
    }

+0 −6
Original line number Diff line number Diff line
@@ -287,12 +287,6 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n
            namespace = "aws_smithy_http::response",
        )

        fun jsonDeserialize(runtimeConfig: RuntimeConfig) = RuntimeType(
            name = "Error",
            dependency = CargoDependency.smithyJson(runtimeConfig),
            namespace = "aws_smithy_json::deserialize",
        )

        fun ec2QueryErrors(runtimeConfig: RuntimeConfig) =
            forInlineDependency(InlineDependency.ec2QueryErrors(runtimeConfig))

+10 −3
Original line number Diff line number Diff line
@@ -250,7 +250,14 @@ class JsonParserGenerator(
            when (target.hasTrait<EnumTrait>()) {
                true -> {
                    if (convertsToEnumInServer(target)) {
                        rust("#T::try_from(u.as_ref())", symbolProvider.toSymbol(target))
                        rustTemplate(
                            """
                            #{EnumSymbol}::try_from(u.as_ref())
                                .map_err(|e| #{Error}::custom(format!("unknown variant {}", e)))
                            """,
                            "EnumSymbol" to symbolProvider.toSymbol(target),
                            *codegenScope,
                        )
                    } else {
                        rust("#T::from(u.as_ref())", symbolProvider.toSymbol(target))
                    }
@@ -263,8 +270,8 @@ class JsonParserGenerator(
    private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait<EnumTrait>()

    private fun RustWriter.deserializeString(target: StringShape) {
        // additional .transpose()? because Rust does not allow ? up from closures
        val additionalTranspose = if (convertsToEnumInServer(target)) { ".transpose()?".repeat(2) } else { ".transpose()?" }
        // Additional `.transpose()?` because we can't use `?` inside the closures that parsed the string.
        val additionalTranspose = ".transpose()?".repeat(if (convertsToEnumInServer(target)) 2 else 1)
        withBlockTemplate("#{expect_string_or_null}(tokens.next())?.map(|s|", ")$additionalTranspose", *codegenScope) {
            deserializeStringInner(target, "s")
        }
+17 −7
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
@@ -654,18 +655,27 @@ class XmlBindingTraitParserGenerator(

    private fun RustWriter.parseStringInner(shape: StringShape, provider: RustWriter.() -> Unit) {
        withBlock("Result::<#T, #T>::Ok(", ")", symbolProvider.toSymbol(shape), xmlError) {
            if (!shape.hasTrait<EnumTrait>()) {
            if (shape.hasTrait<EnumTrait>()) {
                val enumSymbol = symbolProvider.toSymbol(shape)
                if (convertsToEnumInServer(shape)) {
                    withBlock("#T::try_from(", ")", enumSymbol) {
                        provider()
                // if it's already `Cow::Owned` then `.into()` is free (vs. to_string())
                rust(".into()")
                    }
                    rustTemplate(""".map_err(|e| #{XmlError}::custom(format!("unknown variant {}", e)))?""", *codegenScope)
                } else {
                val enumSymbol = symbolProvider.toSymbol(shape)
                    withBlock("#T::from(", ")", enumSymbol) {
                        provider()
                    }
                }
            } else {
                provider()
                // If it's already `Cow::Owned` then `.into()` is free (as opposed to using `to_string()`).
                rust(".into()")
            }
        }
    }

    private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait<EnumTrait>()

    private fun MemberShape.xmlName(): XmlName {
        return XmlName(xmlIndex.memberName(this))