Unverified Commit 9b40b38d authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

Fix Enum Serialization (#160)

Named enums were being incorrectly serialized via the derive macro. This was causing the value of the enum to be disregarded. This change implements `Serialize` and `Deserialize` explicitly for enums.

Additionally, a new protocol test was added that avoid using sets–this allow it to be run without being flaky until #37 is fixed.
parent cdb63348
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ use dynamo::output::ListTablesOutput;
async fn main() -> Result<(), Box<dyn Error>> {
    let table_name = "new_table";
    let client = io_v0::Client::local("dynamodb");
    let config = dynamo::Config::from_env();
    let config = dynamo::Config::builder().build();
    let clear_table = dynamo::operation::DeleteTable::builder()
        .table_name(table_name)
        .build(&config);
@@ -64,6 +64,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
        )
        .build(&config);

    let body = String::from_utf8(create_table.build_http_request().body().clone()).unwrap();
    println!("{}", body);

    let response = io_v0::dispatch!(client, create_table);
    match response.parsed {
        Some(Ok(output)) => {
+48 −0
Original line number Diff line number Diff line
$version: "1.0"

namespace aws.protocoltests.json

use aws.protocols#awsJson1_1
use aws.protocoltests.shared#FooEnum
use aws.protocoltests.shared#FooEnumList
use aws.protocoltests.shared#FooEnumSet
use aws.protocoltests.shared#FooEnumMap
use smithy.test#httpRequestTests
use smithy.test#httpResponseTests

// The enum protocol test is disabled because of set ordering non-determinism. Add a simpler test
apply JsonEnums @httpRequestTests([
    {
        id: "AwsJson11EnumsBasic",
        documentation: "Serializes simple scalar properties",
        protocol: awsJson1_1,
        method: "POST",
        uri: "/",
        body: """
              {
                  "fooEnum1": "Foo",
                  "fooEnum2": "0",
                  "fooEnum3": "1",
                  "fooEnumList": [
                      "Foo",
                      "0"
                  ],
                  "fooEnumMap": {
                      "hi": "Foo",
                      "zero": "0"
                  }
              }""",
        headers: {"Content-Type": "application/x-amz-json-1.1"},
        bodyMediaType: "application/json",
        params: {
            fooEnum1: "Foo",
            fooEnum2: "0",
            fooEnum3: "1",
            fooEnumList: ["Foo", "0"],
            fooEnumMap: {
                "hi": "Foo",
                "zero": "0"
            }
        }
    }
])
+30 −3
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.expectRustMetadata
@@ -28,6 +29,7 @@ class EnumGenerator(
    private val symbol = symbolProvider.toSymbol(shape)
    private val enumName = symbol.name
    private val meta = symbol.expectRustMetadata()

    companion object {
        const val Values = "values"
    }
@@ -41,10 +43,11 @@ class EnumGenerator(
            renderFromStr()
            writer.insertTrailingNewline()
            // impl Blah { pub fn as_str(&self) -> &str
            renderAsStr()
            implBlock()
        } else {
            renderUnamedEnum()
        }
        renderSerde()
    }

    private fun renderUnamedEnum() {
@@ -74,7 +77,8 @@ class EnumGenerator(
        // Because enum variants always start with an upper case letter, they will never
        // conflict with reserved words (which are always lower case), therefore, we never need
        // to fall back to raw identifiers
        return name.orElse(null)?.toPascalCase() ?: throw IllegalStateException("Enum variants must be named to derive a name. This is a bug.")
        return name.orElse(null)?.toPascalCase()
            ?: throw IllegalStateException("Enum variants must be named to derive a name. This is a bug.")
    }

    private fun renderEnum() {
@@ -89,7 +93,7 @@ class EnumGenerator(
        }
    }

    private fun renderAsStr() {
    private fun implBlock() {
        // TODO: should enums also implement AsRef<str>?
        writer.rustBlock("impl $enumName") {
            writer.rustBlock("pub fn as_str(&self) -> &str") {
@@ -103,6 +107,29 @@ class EnumGenerator(
        }
    }

    private fun renderSerde() {
        writer.rustTemplate(
            """
                impl #{serialize} for $enumName {
                    fn serialize<S>(&self, serializer: S) -> Result<<S as #{serializer}>::Ok, <S as #{serializer}>::Error> where S: #{serializer}{
                        serializer.serialize_str(self.as_str())
                    }
                }

                impl<'de> #{deserialize}<'de> for $enumName {
                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: #{deserializer}<'de> {
                        let data = <&str>::deserialize(deserializer)?;
                        Ok(Self::from(data))
                    }
                }
            """,
            "serializer" to RuntimeType.Serializer,
            "serialize" to RuntimeType.Serialize,
            "deserializer" to RuntimeType.Deserializer,
            "deserialize" to RuntimeType.Deserialize
        )
    }

    private fun renderFromStr() {
        writer.rustBlock("impl <T> #T<T> for $enumName where T: #T<str>", RuntimeType.From, RuntimeType.AsRef) {
            writer.rustBlock("fn from(s: T) -> Self") {
+1 −1
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.rustType
 */
class CustomSerializerGenerator(
    private val symbolProvider: RustSymbolProvider,
    model: Model,
    private val model: Model,
    private val defaultTimestampFormat: TimestampFormatTrait.Format
) {
    private val inp = "_inp"
+3 −2
Original line number Diff line number Diff line
@@ -66,7 +66,8 @@ class JsonSerializerSymbolProvider(
            serializerBuilder.deserializerFor(memberShape)?.also {
                attribs.add(Custom("serde(deserialize_with = ${it.fullyQualifiedName().dq()})", listOf(it)))
            }
            if (model.expectShape(memberShape.container) is StructureShape && base.toSymbol(memberShape).isOptional()
            if (model.expectShape(memberShape.container) is StructureShape && base.toSymbol(memberShape)
                .isOptional()
            ) {
                attribs.add(Custom("serde(default)"))
            }
@@ -76,7 +77,7 @@ class JsonSerializerSymbolProvider(

    override fun structureMeta(structureShape: StructureShape): RustMetadata = containerMeta(structureShape)
    override fun unionMeta(unionShape: UnionShape): RustMetadata = containerMeta(unionShape)
    override fun enumMeta(stringShape: StringShape): RustMetadata = containerMeta(stringShape)
    override fun enumMeta(stringShape: StringShape): RustMetadata = base.toSymbol(stringShape).expectRustMetadata()

    private fun containerMeta(container: Shape): RustMetadata {
        val currentMeta = base.toSymbol(container).expectRustMetadata()
Loading