Unverified Commit 8b539938 authored by Russell Cohen's avatar Russell Cohen Committed by GitHub
Browse files

restXML Protocol Serializers (#394)



* restXML Protocol tests passing

* Fix bug in namespace priority resolution

* Fix unit test

* CR feedback

* Fix handwritten serializer test

* Fix the test, take 2

* Update codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/XmlBindingTraitSerializerGenerator.kt

fix typo

Co-authored-by: default avatarJohn DiSanti <johndisanti@gmail.com>

Co-authored-by: default avatarJohn DiSanti <johndisanti@gmail.com>
parent 392463d0
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -411,9 +411,10 @@ class HttpProtocolTestGenerator(
        // These could be configured via runtime configuration, but since this won't be long-lasting,
        // it makes sense to do the simplest thing for now.
        // The test will _fail_ if these pass, so we will discover & remove if we fix them by accident
        val JsonRpc10 = "aws.protocoltests.json10#JsonRpc10"
        val AwsJson11 = "aws.protocoltests.json#JsonProtocol"
        val RestJson = "aws.protocoltests.restjson#RestJson"
        private val JsonRpc10 = "aws.protocoltests.json10#JsonRpc10"
        private val AwsJson11 = "aws.protocoltests.json#JsonProtocol"
        private val RestJson = "aws.protocoltests.restjson#RestJson"
        private val RestXml = "aws.protocoltests.restxml#RestXml"
        private val ExpectFail = setOf(
            // Endpoint trait https://github.com/awslabs/smithy-rs/issues/197
            // This will also require running operations through the endpoint middleware (or moving endpoint middleware
@@ -424,6 +425,9 @@ class HttpProtocolTestGenerator(
            FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTrait", Action.Request),
            FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestXml, "RestXmlEndpointTraitWithHostLabelAndHttpBinding", Action.Request),
            FailingTest(RestXml, "RestXmlEndpointTraitWithHostLabel", Action.Request),
            FailingTest(RestXml, "RestXmlEndpointTrait", Action.Request)
        )
        private val RunOnly: Set<String>? = null

+4 −39
Original line number Diff line number Diff line
@@ -7,16 +7,11 @@ package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.render
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory
@@ -24,11 +19,10 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.XmlBindingTraitParserGenerator
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.XmlBindingTraitSerializerGenerator
import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.toSnakeCase

class RestXmlFactory : ProtocolGeneratorFactory<HttpTraitProtocolGenerator> {
    override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpTraitProtocolGenerator {
@@ -44,8 +38,8 @@ class RestXmlFactory : ProtocolGeneratorFactory<HttpTraitProtocolGenerator> {

    override fun support(): ProtocolSupport {
        return ProtocolSupport(
            requestSerialization = false,
            requestBodySerialization = false,
            requestSerialization = true,
            requestBodySerialization = true,
            responseDeserialization = true,
            errorDeserialization = true
        )
@@ -65,7 +59,7 @@ class RestXml(private val protocolConfig: ProtocolConfig) : Protocol {
    }

    override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator {
        return RestXmlSerializer(protocolConfig)
        return XmlBindingTraitSerializerGenerator(protocolConfig)
    }

    override fun parseGenericError(operationShape: OperationShape): RuntimeType {
@@ -91,32 +85,3 @@ class RestXml(private val protocolConfig: ProtocolConfig) : Protocol {

    override fun defaultContentType(): String = "application/xml"
}

class RestXmlSerializer(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator {
    private val symbolProvider = protocolConfig.symbolProvider
    private val runtimeConfig = protocolConfig.runtimeConfig
    private val model = protocolConfig.model
    override fun payloadSerializer(member: MemberShape): RuntimeType {
        val target = model.expectShape(member.target)
        val fnName = "serialize_payload_${target.id.name.toSnakeCase()}_${member.container.name.toSnakeCase()}"
        return RuntimeType.forInlineFun(fnName, "operation_ser") {
            val t = symbolProvider.toSymbol(member).rustType().stripOuter<RustType.Option>().render(true)
            it.rustBlock(
                "pub fn $fnName(_input: &$t) -> Result<#T, String>",

                RuntimeType.sdkBody(runtimeConfig),
            ) {
                rust("todo!()")
            }
        }
    }

    override fun operationSerializer(operationShape: OperationShape): RuntimeType? {
        return null
    }

    override fun documentSerializer(): RuntimeType {
        // RestXML does not support documents
        TODO("Not yet implemented")
    }
}
+67 −0
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.protocols

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.KnowledgeIndex
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.XmlAttributeTrait
import software.amazon.smithy.model.traits.XmlNameTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape

class XmlNameIndex(private val model: Model) : KnowledgeIndex {
    companion object {
        fun of(model: Model): XmlNameIndex {
            return model.getKnowledge(XmlNameIndex::class.java, ::XmlNameIndex)
        }
    }

    fun payloadShapeName(member: MemberShape): String {
        val payloadShape = model.expectShape(member.target)
        val xmlRename: XmlNameTrait? = member.getTrait() ?: payloadShape.getTrait()
        return xmlRename?.value ?: payloadShape.id.name
    }

    /**
     * XmlName for an operation output
     *
     * When an operation has no output body, null is returned
     */
    fun operationOutputShapeName(operationShape: OperationShape): String? {
        val outputShape = operationShape.outputShape(model)
        val rename = outputShape.getTrait<XmlNameTrait>()?.value
        return rename ?: outputShape.expectTrait<SyntheticOutputTrait>().originalId?.name
    }

    fun operationInputShapeName(operationShape: OperationShape): String? {
        val outputShape = operationShape.inputShape(model)
        val rename = outputShape.getTrait<XmlNameTrait>()?.value
        return rename ?: outputShape.expectTrait<SyntheticInputTrait>().originalId?.name
    }

    fun memberName(member: MemberShape): String {
        val override = member.getTrait<XmlNameTrait>()?.value
        return override ?: member.memberName
    }
}

data class XmlMemberIndex(val dataMembers: List<MemberShape>, val attributeMembers: List<MemberShape>) {
    companion object {
        fun fromMembers(members: List<MemberShape>): XmlMemberIndex {
            val (attribute, data) = members.partition { it.hasTrait<XmlAttributeTrait>() }
            return XmlMemberIndex(data, attribute)
        }
    }

    fun isNotEmpty() = dataMembers.isNotEmpty() || attributeMembers.isNotEmpty()
}
+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe
                StructureShape::class.java
            )
        } ?: return null
        val fnName = "serialize_synthetic_${inputBody.id.name.toSnakeCase()}"
        val fnName = "serialize_operation_${inputBody.id.name.toSnakeCase()}"
        return RuntimeType.forInlineFun(fnName, "operation_ser") {
            it.rustBlockTemplate(
                "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>",
+13 −52
Original line number Diff line number Diff line
@@ -21,9 +21,7 @@ import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.model.traits.XmlAttributeTrait
import software.amazon.smithy.model.traits.XmlFlattenedTrait
import software.amazon.smithy.model.traits.XmlNameTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
@@ -42,11 +40,10 @@ import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.smithy.isBoxed
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toPascalCase
@@ -59,20 +56,9 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
     * Abstraction to represent an XML element name:
     * `[prefix]:[local]`
     */
    data class XmlName(val local: String, val prefix: String? = null) {
    data class XmlName(val name: String) {
        override fun toString(): String {
            return prefix?.let { "$it:" }.orEmpty() + local
        }

        companion object {
            fun parse(v: String): XmlName {
                val split = v.indexOf(':')
                return if (split == -1) {
                    XmlName(local = v, prefix = null)
                } else {
                    XmlName(v.substring(split + 1), prefix = v.substring(0, split))
                }
            }
            return name
        }
    }

@@ -105,6 +91,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
    )
    private val model = protocolConfig.model
    private val index = HttpBindingIndex.of(model)
    private val xmlIndex = XmlNameIndex.of(model)

    /**
     * Generate a parse function for a given targeted as a payload.
@@ -130,7 +117,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
                // for payloads, first look at the member trait
                // next, look to see if this structure was renamed

                val shapeName = payloadName(member)
                val shapeName = XmlName(xmlIndex.payloadShapeName(member))
                rustTemplate(
                    """
                    use std::convert::TryFrom;
@@ -138,7 +125,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
                    ##[allow(unused_mut)]
                    let mut decoder = doc.root_element()?;
                    let start_el = decoder.start_el();
                    if !(${shapeName.compareTo("start_el")}) {
                    if !(${shapeName.matchExpression("start_el")}) {
                        return Err(#{XmlError}::custom(format!("invalid root, expected $shapeName got {:?}", start_el)))
                    }
                    """,
@@ -155,13 +142,6 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
        }
    }

    private fun payloadName(member: MemberShape): XmlName {
        val payloadShape = model.expectShape(member.target)
        val xmlRename = member.getTrait<XmlNameTrait>() ?: payloadShape.getTrait()

        return xmlRename?.let { XmlName.parse(it.value) } ?: XmlName(local = payloadShape.id.name)
    }

    /** Generate a parser for operation input
     * Because only a subset of fields of the operation may be impacted by the document, a builder is passed
     * through:
@@ -175,7 +155,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
    override fun operationParser(operationShape: OperationShape): RuntimeType? {
        val outputShape = operationShape.outputShape(model)
        val fnName = operationShape.id.name.toString().toSnakeCase() + "_deser_operation"
        val shapeName = operationXmlName(outputShape)
        val shapeName = xmlIndex.operationOutputShapeName(operationShape)
        val members = operationShape.operationXmlMembers()
        if (shapeName == null || !members.isNotEmpty()) {
            return null
@@ -193,7 +173,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
                    ##[allow(unused_mut)]
                    let mut decoder = doc.root_element()?;
                    let start_el = decoder.start_el();
                    if !(${shapeName.compareTo("start_el")}) {
                    if !(${XmlName(shapeName).matchExpression("start_el")}) {
                        return Err(#{XmlError}::custom(format!("invalid root, expected $shapeName got {:?}", start_el)))
                    }
                    """,
@@ -205,13 +185,6 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
        }
    }

    private fun operationXmlName(outputShape: StructureShape): XmlName? {
        return outputShape.getTrait<XmlNameTrait>()?.let { XmlName.parse(it.value) }
            ?: outputShape.expectTrait<SyntheticOutputTrait>().originalId?.name?.let {
                XmlName(local = it, prefix = null)
            }
    }

    override fun errorParser(errorShape: StructureShape): RuntimeType {
        val fnName = errorShape.id.name.toString().toSnakeCase()
        return RuntimeType.forInlineFun(fnName, "xml_deser") {
@@ -384,7 +357,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
    private fun RustWriter.case(member: MemberShape, inner: RustWriter.() -> Unit) {
        rustBlock(
            "s if ${
            member.xmlName().compareTo("s")
            member.xmlName().matchExpression("s")
            } /* ${member.memberName} ${escape(member.id.toString())} */ => "
        ) {
            inner()
@@ -463,7 +436,7 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
            ) {
                rust("let mut out = #T::new();", RustType.HashMap.RuntimeType)
                parseLoop(Ctx(tag = "decoder", accum = null)) { ctx ->
                    rustBlock("s if ${XmlName(local = "entry").compareTo("s")} => ") {
                    rustBlock("s if ${XmlName("entry").matchExpression("s")} => ") {
                        rust("#T(&mut ${ctx.tag}, &mut out)?;", mapEntryParser(target, ctx))
                    }
                }
@@ -598,28 +571,16 @@ class XmlBindingTraitParserGenerator(protocolConfig: ProtocolConfig, private val
    }

    private fun MemberShape.xmlName(): XmlName {
        val override = this.getTrait<XmlNameTrait>()
        return override?.let { XmlName.parse(it.value) } ?: XmlName(local = this.memberName)
        return XmlName(xmlIndex.memberName(this))
    }

    private fun MemberShape.isFlattened(): Boolean {
        return getMemberTrait(model, XmlFlattenedTrait::class.java).isPresent
    }

    fun XmlName.compareTo(start_el: String) =
    fun XmlName.matchExpression(start_el: String) =
        "$start_el.matches(${this.toString().dq()})"

    data class XmlMemberIndex(val dataMembers: List<MemberShape>, val attributeMembers: List<MemberShape>) {
        companion object {
            fun fromMembers(members: List<MemberShape>): XmlMemberIndex {
                val (attribute, data) = members.partition { it.hasTrait<XmlAttributeTrait>() }
                return XmlMemberIndex(data, attribute)
            }
        }

        fun isNotEmpty() = dataMembers.isNotEmpty() || attributeMembers.isNotEmpty()
    }

    private fun OperationShape.operationXmlMembers(): XmlMemberIndex {
        val outputShape = this.outputShape(model)
        val documentMembers =
Loading