Unverified Commit 2beb6009 authored by John DiSanti's avatar John DiSanti Committed by GitHub
Browse files

Implement the AWS Query serialization codegen (#436)

* Implement the AWS Query serialization codegen

* Fix bad import

* Simplify `serializeStructure`

* CR feedback
parent 814192d7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -191,6 +191,7 @@ data class CargoDependency(
        )

        fun smithyJson(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("json")
        fun smithyQuery(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("query")
        fun smithyXml(runtimeConfig: RuntimeConfig): CargoDependency = runtimeConfig.runtimeCrate("xml")

        val SerdeJson: CargoDependency =
+2 −2
Original line number Diff line number Diff line
@@ -154,6 +154,8 @@ class BasicAwsJsonGenerator(
) : HttpProtocolGenerator(protocolConfig) {
    private val model = protocolConfig.model
    private val runtimeConfig = protocolConfig.runtimeConfig
    private val symbolProvider = protocolConfig.symbolProvider
    private val operationIndex = OperationIndex.of(model)

    override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) {
        val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model))
@@ -175,8 +177,6 @@ class BasicAwsJsonGenerator(
        )
    }

    private val symbolProvider = protocolConfig.symbolProvider
    private val operationIndex = OperationIndex.of(model)
    override fun toHttpRequestImpl(
        implBlockWriter: RustWriter,
        operationShape: OperationShape,
+12 −4
Original line number Diff line number Diff line
@@ -46,8 +46,8 @@ class AwsQueryFactory : ProtocolGeneratorFactory<AwsQueryProtocolGenerator> {

    override fun support(): ProtocolSupport {
        return ProtocolSupport(
            requestSerialization = false,
            requestBodySerialization = false,
            requestSerialization = true,
            requestBodySerialization = true,
            responseDeserialization = true,
            errorDeserialization = true,
        )
@@ -110,8 +110,16 @@ class AwsQueryProtocolGenerator(private val protocolConfig: ProtocolConfig) : Ht
        inputShape: StructureShape
    ) {
        httpBuilderFun(implBlockWriter) {
            // TODO: Implement request building
            rust("unimplemented!()")
            rust(
                """
                Ok(
                    #T::new()
                        .method("POST")
                        .header("Content-Type", "application/x-www-form-urlencoded")
                )
                """,
                RuntimeType.HttpRequestBuilder
            )
        }
    }

+5 −5
Original line number Diff line number Diff line
@@ -397,7 +397,7 @@ class HttpTraitProtocolGenerator(
                let builder = builder.header("Content-Type", ${contentType.dq()});
                self.update_http_builder(builder)
                """,
                RuntimeType.Http("request::Builder")
                RuntimeType.HttpRequestBuilder
            )
        }
    }
+262 −16
Original line number Diff line number Diff line
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

package software.amazon.smithy.rust.codegen.smithy.protocols.serialize

import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.model.traits.XmlFlattenedTrait
import software.amazon.smithy.model.traits.XmlNameTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.serializeFunctionName
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.toPascalCase

class AwsQuerySerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator {
    private data class Context<T : Shape>(
        /** Expression that yields a QueryValueWriter */
        val writerExpression: String,
        /** Expression representing the value to write to the QueryValueWriter */
        val valueExpression: ValueExpression,
        val shape: T,
    )

    private data class MemberContext(
        /** Expression that yields a QueryValueWriter */
        val writerExpression: String,
        /** Expression representing the value to write to the QueryValueWriter */
        val valueExpression: ValueExpression,
        val shape: MemberShape,
    ) {
        companion object {
            fun structMember(context: Context<StructureShape>, member: MemberShape, symProvider: RustSymbolProvider): MemberContext =
                MemberContext(
                    context.writerExpression,
                    ValueExpression.Value("${context.valueExpression.name}.${symProvider.toMemberName(member)}"),
                    member
                )

            fun unionMember(context: Context<UnionShape>, variantReference: String, member: MemberShape): MemberContext =
                MemberContext(
                    context.writerExpression,
                    ValueExpression.Reference(variantReference),
                    member
                )
        }
    }

    private val model = protocolConfig.model
    private val symbolProvider = protocolConfig.symbolProvider
    private val runtimeConfig = protocolConfig.runtimeConfig
    private val serviceShape = protocolConfig.serviceShape
    private val serializerError = RuntimeType.SerdeJson("error::Error")
    private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType()
    private val smithyQuery = CargoDependency.smithyQuery(runtimeConfig).asType()
    private val codegenScope = arrayOf(
        "String" to RuntimeType.String,
        "Error" to serializerError,
        "SdkBody" to RuntimeType.sdkBody(runtimeConfig),
        "QueryWriter" to smithyQuery.member("QueryWriter"),
        "QueryValueWriter" to smithyQuery.member("QueryValueWriter"),
    )

    override fun documentSerializer(): RuntimeType {
        TODO("AwsQuery doesn't support document types")
    }

    override fun payloadSerializer(member: MemberShape): RuntimeType {
        val fnName = symbolProvider.serializeFunctionName(member)
        val target = model.expectShape(member.target, StructureShape::class.java)
        TODO("The Aws Query protocol doesn't support http payload traits")
    }

    override fun operationSerializer(operationShape: OperationShape): RuntimeType? {
        val fnName = symbolProvider.serializeFunctionName(operationShape)
        val inputShape = operationShape.inputShape(model)
        return RuntimeType.forInlineFun(fnName, "operation_ser") { writer ->
            writer.rustBlockTemplate(
                "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>",
                *codegenScope,
                "target" to symbolProvider.toSymbol(target)
                *codegenScope, "target" to symbolProvider.toSymbol(inputShape)
            ) {
                // TODO: Implement query payload serializer
                writer.rust("unimplemented!()")
                val action = operationShape.id.name
                val version = serviceShape.version

                if (inputShape.members().isEmpty()) {
                    rust("let _ = input;")
                }
                rust("let mut out = String::new();")
                Attribute.AllowUnusedMut.render(writer)
                rustTemplate(
                    "let mut writer = #{QueryWriter}::new(&mut out, ${action.dq()}, ${version.dq()});",
                    *codegenScope
                )
                serializeStructureInner(Context("writer", ValueExpression.Reference("input"), inputShape))
                rust("writer.finish();")
                rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope)
            }
        }
    }

    override fun operationSerializer(operationShape: OperationShape): RuntimeType? {
        val fnName = symbolProvider.serializeFunctionName(operationShape)
        val inputShape = operationShape.inputShape(model)
        return RuntimeType.forInlineFun(fnName, "operation_ser") { writer ->
    private fun RustWriter.serializeStructure(context: Context<StructureShape>) {
        val fnName = symbolProvider.serializeFunctionName(context.shape)
        val structureSymbol = symbolProvider.toSymbol(context.shape)
        val structureSerializer = RuntimeType.forInlineFun(fnName, "query_ser") { writer ->
            Attribute.AllowUnusedMut.render(writer)
            writer.rustBlockTemplate(
                "pub fn $fnName(_input: &#{target}) -> Result<#{SdkBody}, #{Error}>",
                *codegenScope, "target" to symbolProvider.toSymbol(inputShape)
                "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input})",
                "Input" to structureSymbol,
                *codegenScope
            ) {
                // TODO: Implement query operation serializer
                writer.rust("unimplemented!()")
                if (context.shape.members().isEmpty()) {
                    rust("let (_, _) = (writer, input);") // Suppress unused argument warnings
                }
                serializeStructureInner(context)
            }
        }
        rust("#T(${context.writerExpression}, ${context.valueExpression.name});", structureSerializer)
    }

    override fun documentSerializer(): RuntimeType {
        TODO("AwsQuery doesn't support document types")
    private fun RustWriter.serializeStructureInner(context: Context<StructureShape>) {
        context.copy(writerExpression = "writer", valueExpression = ValueExpression.Reference("input"))
            .also { inner ->
                for (member in inner.shape.members()) {
                    val memberContext = MemberContext.structMember(inner, member, symbolProvider)
                    structWriter(memberContext) { writerExpression ->
                        serializeMember(memberContext.copy(writerExpression = writerExpression))
                    }
                }
            }
    }

    private fun RustWriter.serializeMember(context: MemberContext) {
        val targetShape = model.expectShape(context.shape.target)
        if (symbolProvider.toSymbol(context.shape).isOptional()) {
            safeName().also { local ->
                rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
                    val innerContext = context.copy(valueExpression = ValueExpression.Reference(local))
                    serializeMemberValue(innerContext, targetShape)
                }
            }
        } else {
            serializeMemberValue(context, targetShape)
        }
    }

    private fun RustWriter.serializeMemberValue(context: MemberContext, target: Shape) {
        val writer = context.writerExpression
        val value = context.valueExpression
        when (target) {
            is StringShape -> when (target.hasTrait<EnumTrait>()) {
                true -> rust("$writer.string(${value.name}.as_str());")
                false -> rust("$writer.string(${value.name});")
            }
            is BooleanShape -> rust("$writer.boolean(${value.asValue()});")
            is NumberShape -> {
                val numberType = when (symbolProvider.toSymbol(target).rustType()) {
                    is RustType.Float -> "Float"
                    // NegInt takes an i64 while PosInt takes u64. We need this to be signed here
                    is RustType.Integer -> "NegInt"
                    else -> throw IllegalStateException("unreachable")
                }
                rust(
                    "$writer.number(##[allow(clippy::useless_conversion)]#T::$numberType((${value.asValue()}).into()));",
                    smithyTypes.member("Number")
                )
            }
            is BlobShape -> rust(
                "$writer.string(&#T(${value.name}));",
                RuntimeType.Base64Encode(runtimeConfig)
            )
            is TimestampShape -> {
                val timestampFormat = determineTimestampFormat(context.shape)
                val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
                rust("$writer.instant(${value.name}, #T);", timestampFormatType)
            }
            is CollectionShape -> serializeCollection(context, Context(writer, context.valueExpression, target))
            is MapShape -> serializeMap(context, Context(writer, context.valueExpression, target))
            is StructureShape -> serializeStructure(Context(writer, context.valueExpression, target))
            is UnionShape -> structWriter(context) { writerExpression ->
                serializeUnion(Context(writerExpression, context.valueExpression, target))
            }
            else -> TODO(target.toString())
        }
    }

    private fun determineTimestampFormat(shape: MemberShape): TimestampFormatTrait.Format =
        shape.getMemberTrait(model, TimestampFormatTrait::class.java).orNull()?.format
            ?: TimestampFormatTrait.Format.DATE_TIME

    private fun RustWriter.structWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) {
        val prefix = context.shape.getTrait<XmlNameTrait>()?.value ?: context.shape.memberName
        safeName("scope").also { scopeName ->
            Attribute.AllowUnusedMut.render(this)
            rust("let mut $scopeName = ${context.writerExpression}.prefix(${prefix.dq()});")
            inner(scopeName)
        }
    }

    private fun RustWriter.serializeCollection(memberContext: MemberContext, context: Context<CollectionShape>) {
        val flat = memberContext.shape.getTrait<XmlFlattenedTrait>() != null
        val memberOverride = when (val override = context.shape.member.getTrait<XmlNameTrait>()?.value) {
            null -> "None"
            else -> "Some(${override.dq()})"
        }
        val itemName = safeName("item")
        safeName("list").also { listName ->
            rust("let mut $listName = ${context.writerExpression}.start_list($flat, $memberOverride);")
            rustBlock("for $itemName in ${context.valueExpression.asRef()}") {
                val entryName = safeName("entry")
                Attribute.AllowUnusedMut.render(this)
                rust("let mut $entryName = $listName.entry();")
                val targetShape = model.expectShape(context.shape.member.target)
                serializeMemberValue(
                    MemberContext(entryName, ValueExpression.Reference(itemName), context.shape.member),
                    targetShape
                )
            }
            rust("$listName.finish();")
        }
    }

    private fun RustWriter.serializeMap(memberContext: MemberContext, context: Context<MapShape>) {
        val flat = memberContext.shape.getTrait<XmlFlattenedTrait>() != null
        val entryKeyName = (context.shape.key.getTrait<XmlNameTrait>()?.value ?: "key").dq()
        val entryValueName = (context.shape.value.getTrait<XmlNameTrait>()?.value ?: "value").dq()
        safeName("map").also { mapName ->
            val keyName = safeName("key")
            val valueName = safeName("value")
            rust("let mut $mapName = ${context.writerExpression}.start_map($flat, $entryKeyName, $entryValueName);")
            rustBlock("for ($keyName, $valueName) in ${context.valueExpression.asRef()}") {
                val keyTarget = model.expectShape(context.shape.key.target)
                val keyExpression = when (keyTarget.hasTrait<EnumTrait>()) {
                    true -> "$keyName.as_str()"
                    else -> keyName
                }
                val entryName = safeName("entry")
                Attribute.AllowUnusedMut.render(this)
                rust("let mut $entryName = $mapName.entry($keyExpression);")
                serializeMember(MemberContext(entryName, ValueExpression.Reference(valueName), context.shape.value))
            }
            rust("$mapName.finish();")
        }
    }

    private fun RustWriter.serializeUnion(context: Context<UnionShape>) {
        val fnName = symbolProvider.serializeFunctionName(context.shape)
        val unionSymbol = symbolProvider.toSymbol(context.shape)
        val unionSerializer = RuntimeType.forInlineFun(fnName, "query_ser") { writer ->
            Attribute.AllowUnusedMut.render(writer)
            writer.rustBlockTemplate(
                "pub fn $fnName(mut writer: #{QueryValueWriter}, input: &#{Input})",
                "Input" to unionSymbol,
                *codegenScope,
            ) {
                rustBlock("match input") {
                    for (member in context.shape.members()) {
                        val variantName = member.memberName.toPascalCase()
                        withBlock("#T::$variantName(inner) => {", "},", unionSymbol) {
                            serializeMember(
                                MemberContext.unionMember(
                                    context.copy(writerExpression = "writer"),
                                    "inner",
                                    member
                                )
                            )
                        }
                    }
                }
            }
        }
        rust("#T(${context.writerExpression}, ${context.valueExpression.asRef()});", unionSerializer)
    }
}
Loading