Unverified Commit 6ce79e88 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Refactor XML deserialization to use InlineFunctionNamer and add test (#497)

* Refactor XML deserialization to use InlineFunctionNamer and add test

* Fix XML error parser function naming issue
parent 4ad0d228
Loading
Loading
Loading
Loading
+15 −3
Original line number Diff line number Diff line
@@ -5,8 +5,11 @@

package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.SetShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
@@ -17,10 +20,13 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase
 * Creates a unique name for a serialization function.
 *
 * The prefixes will look like the following (for grep):
 * - serialize_list
 * - serialize_map
 * - serialize_member
 * - serialize_operation
 * - serialize_set
 * - serialize_structure
 * - serialize_union
 * - serialize_payload
 */
fun RustSymbolProvider.serializeFunctionName(shape: Shape): String = shapeFunctionName("serialize", shape)

@@ -28,20 +34,26 @@ fun RustSymbolProvider.serializeFunctionName(shape: Shape): String = shapeFuncti
 * Creates a unique name for a serialization function.
 *
 * The prefixes will look like the following (for grep):
 * - deser_list
 * - deser_map
 * - deser_member
 * - deser_operation
 * - deser_set
 * - deser_structure
 * - deser_union
 * - deser_payload
 */
fun RustSymbolProvider.deserializeFunctionName(shape: Shape): String = shapeFunctionName("deser", shape)

private fun RustSymbolProvider.shapeFunctionName(prefix: String, shape: Shape): String {
    val symbolNameSnakeCase = toSymbol(shape).name.toSnakeCase()
    return prefix + "_" + when (shape) {
        is ListShape -> "list_${shape.id.name.toSnakeCase()}"
        is MapShape -> "map_${shape.id.name.toSnakeCase()}"
        is MemberShape -> "member_${shape.container.name.toSnakeCase()}_${shape.memberName.toSnakeCase()}"
        is OperationShape -> "operation_$symbolNameSnakeCase"
        is SetShape -> "set_${shape.id.name.toSnakeCase()}"
        is StructureShape -> "structure_$symbolNameSnakeCase"
        is UnionShape -> "union_$symbolNameSnakeCase"
        is MemberShape -> "payload_${shape.target.name.toSnakeCase()}_${shape.container.name.toSnakeCase()}"
        else -> TODO("SerializerFunctionNamer.name: $shape")
    }
}
+7 −12
Original line number Diff line number Diff line
@@ -49,7 +49,6 @@ import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toPascalCase
import software.amazon.smithy.rust.codegen.util.toSnakeCase

// The string argument is the name of the XML ScopedDecoder to continue parsing from
typealias OperationInnerWriteable = RustWriter.(String) -> Unit
@@ -201,7 +200,7 @@ class XmlBindingTraitParserGenerator(
    }

    override fun errorParser(errorShape: StructureShape): RuntimeType {
        val fnName = errorShape.id.name.toString().toSnakeCase()
        val fnName = symbolProvider.deserializeFunctionName(errorShape) + "_xml_err"
        return RuntimeType.forInlineFun(fnName, "xml_deser") {
            Attribute.AllowUnusedMut.render(it)
            it.rustBlock(
@@ -338,7 +337,7 @@ class XmlBindingTraitParserGenerator(
    }

    private fun RustWriter.parseUnion(shape: UnionShape, ctx: Ctx) {
        val fnName = shape.id.name.toString().toSnakeCase() + "_inner"
        val fnName = symbolProvider.deserializeFunctionName(shape)
        val symbol = symbolProvider.toSymbol(shape)
        val nestedParser = RuntimeType.forInlineFun(fnName, "xml_deser") {
            it.rustBlockTemplate(
@@ -387,7 +386,7 @@ class XmlBindingTraitParserGenerator(
    }

    private fun RustWriter.parseStructure(shape: StructureShape, ctx: Ctx) {
        val fnName = shape.id.name.toString().toSnakeCase() + "_inner"
        val fnName = symbolProvider.deserializeFunctionName(shape)
        val symbol = symbolProvider.toSymbol(shape)
        val nestedParser = RuntimeType.forInlineFun(fnName, "xml_deser") {
            it.rustBlockTemplate(
@@ -418,7 +417,7 @@ class XmlBindingTraitParserGenerator(
    }

    private fun RustWriter.parseList(target: CollectionShape, ctx: Ctx) {
        val fnName = "deserialize_${target.member.id.name.toSnakeCase()}"
        val fnName = symbolProvider.deserializeFunctionName(target)
        val member = target.member
        val listParser = RuntimeType.forInlineFun(fnName, "xml_deser") {
            it.rustBlockTemplate(
@@ -453,7 +452,7 @@ class XmlBindingTraitParserGenerator(
    }

    private fun RustWriter.parseMap(target: MapShape, ctx: Ctx) {
        val fnName = "deserialize_${target.value.id.name.toSnakeCase()}"
        val fnName = symbolProvider.deserializeFunctionName(target)
        val mapParser = RuntimeType.forInlineFun(fnName, "xml_deser") {
            it.rustBlockTemplate(
                "pub fn $fnName(decoder: &mut #{ScopedDecoder}) -> Result<#{Map}, #{XmlError}>",
@@ -489,12 +488,8 @@ class XmlBindingTraitParserGenerator(
        }
    }

    private fun mapEntryParser(
        target: MapShape,
        ctx: Ctx
    ): RuntimeType {

        val fnName = target.value.id.name.toSnakeCase() + "_entry"
    private fun mapEntryParser(target: MapShape, ctx: Ctx): RuntimeType {
        val fnName = symbolProvider.deserializeFunctionName(target) + "_entry"
        return RuntimeType.forInlineFun(fnName, "xml_deser") {
            it.rustBlockTemplate(
                "pub fn $fnName(decoder: &mut #{ScopedDecoder}, out: &mut #{Map}) -> Result<(), #{XmlError}>",
+126 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.protocols

import io.kotest.assertions.withClue
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.util.lookup

class InlineFunctionNamerTest {
    private val testModel = """
        namespace test

        structure SomeStruct1 {
            some_string: String,
            some_int: Integer,
        }

        union SomeUnion1 {
            int: Integer,
            long: Long,
        }

        map SomeMap1 {
            key: String,
            value: SomeStruct1,
        }

        list SomeList1 {
            member: Integer,
        }

        set SomeSet1 {
            member: Integer,
        }

        structure Op1Input {
            some_struct: SomeStruct1,
            some_list: SomeList1,
            some_set: SomeSet1,
            some_union: SomeUnion1,
            some_map: SomeMap1,
        }

        operation Op1 {
            input: Op1Input,
        }

        structure SomeStruct2 {
            some_string: String,
            some_int: Integer,
        }

        union SomeUnion2 {
            int: Integer,
            long: Long,
        }

        map SomeMap2 {
            key: String,
            value: SomeStruct2,
        }

        list SomeList2 {
            member: Integer,
        }

        structure Op2Input {
            some_struct: SomeStruct2,
            some_list: SomeList2,
            some_union: SomeUnion2,
            some_map: SomeMap2,
        }

        operation Op2 {
            input: Op1Input,
        }
    """.asSmithyModel()

    class UniqueChecker {
        private val names = HashSet<String>()

        fun checkName(value: String) {
            withClue("Name '$value' should be unique") {
                names.contains(value) shouldBe false
            }
            names.add(value)
        }
    }

    @Test
    fun `generates function names for shapes`() {
        val symbolProvider = testSymbolProvider(testModel)

        fun test(shapeId: String, suffix: String) {
            symbolProvider.serializeFunctionName(testModel.lookup(shapeId)) shouldBe "serialize_$suffix"
            symbolProvider.deserializeFunctionName(testModel.lookup(shapeId)) shouldBe "deser_$suffix"
        }

        test("test#Op1", "operation_op1")
        test("test#SomeList1", "list_some_list1")
        test("test#SomeMap1", "map_some_map1")
        test("test#SomeSet1", "set_some_set1")
        test("test#SomeStruct1", "structure_some_struct1")
        test("test#SomeUnion1", "union_some_union1")
        test("test#SomeStruct1\$some_string", "member_some_struct1_some_string")
    }

    @Test
    fun `generates unique function names for member shapes`() {
        val symbolProvider = testSymbolProvider(testModel)
        UniqueChecker().also { checker ->
            for (shape in testModel.shapes().filter { it.id.namespace == "test" }) {
                for (member in shape.members()) {
                    checker.checkName(symbolProvider.serializeFunctionName(member))
                    checker.checkName(symbolProvider.deserializeFunctionName(member))
                }
            }
        }
    }
}