Unverified Commit 4ad631a6 authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Add support for customizing union variants in JSON/CBOR serialization/de-serialization (#3970)



This PR introduces customization points in JSON and CBOR serialization /
deserialization logic to support customization of the wire format of
Union variant keys.

---------

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 1f9c6088
Loading
Loading
Loading
Loading
+17 −2
Original line number Diff line number Diff line
@@ -45,15 +45,30 @@ operator fun Writable.plus(other: Writable): Writable =

/**
 * Helper allowing a `Iterable<Writable>` to be joined together using a `String` separator.
 * @param separator The string to use as a separator between elements
 * @param prefix An optional string to prepend to the entire joined sequence (defaults to null)
 * @return A Writable containing the optionally prefixed, joined elements
 */
fun Iterable<Writable>.join(separator: String) = join(writable(separator))
fun Iterable<Writable>.join(
    separator: String,
    prefix: String? = null,
) = join(writable(separator), prefix?.let { writable(it) })

/**
 * Helper allowing a `Iterable<Writable>` to be joined together using a `Writable` separator.
 * @param separator The Writable to use as a separator between elements
 * @param prefix An optional Writable to prepend to the entire joined sequence (defaults to null)
 * @return A Writable containing the optionally prefixed, joined elements
 */
fun Iterable<Writable>.join(separator: Writable): Writable {
fun Iterable<Writable>.join(
    separator: Writable,
    prefix: Writable? = null,
): Writable {
    val iter = this.iterator()
    return writable {
        if (iter.hasNext() && prefix != null) {
            prefix()
        }
        iter.forEach { value ->
            value()
            if (iter.hasNext()) {
+9 −2
Original line number Diff line number Diff line
@@ -19,8 +19,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
@@ -92,7 +94,11 @@ class RpcV2CborHttpBindingResolver(
        ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor")
}

open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
open class RpcV2Cbor(
    val codegenContext: CodegenContext,
    private val serializeCustomization: List<CborSerializerCustomization> = listOf(),
    private val parserCustomization: List<CborParserCustomization> = listOf(),
) : Protocol {
    private val runtimeConfig = codegenContext.runtimeConfig

    override val httpBindingResolver: HttpBindingResolver =
@@ -134,10 +140,11 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol {
                    )
                }
            },
            customizations = parserCustomization,
        )

    override fun structuredDataSerializer(): StructuredDataSerializerGenerator =
        CborSerializerGenerator(codegenContext, httpBindingResolver)
        CborSerializerGenerator(codegenContext, httpBindingResolver, customizations = serializeCustomization)

    override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
        RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata")
+70 −4
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
@@ -57,10 +58,29 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape
/** Class describing a CBOR parser section that can be used in a customization. */
sealed class CborParserSection(name: String) : Section(name) {
    data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember")

    /**
     * Represents a customization point in union deserialization that occurs before decoding the map structure.
     * This allows for custom handling of union variants before the standard map decoding logic is applied.
     * @property shape The union shape being deserialized.
     */
    data class UnionParserBeforeDecodingMap(val shape: UnionShape) : CborParserSection("UnionParserBeforeDecodingMap")
}

/** Customization for the CBOR parser. */
typealias CborParserCustomization = NamedCustomization<CborParserSection>
/**
 * Customization class for CBOR parser generation that allows modification of union type deserialization behavior.
 * Previously, union variant discrimination was hardcoded to use `decoder.str()`. This has been made more flexible
 * to support different decoder implementations and discrimination methods.
 */
abstract class CborParserCustomization : NamedCustomization<CborParserSection>() {
    /**
     * Allows customization of how union variants are discriminated during deserialization.
     * @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
     * @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
     */
    open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) =
        defaultContext
}

class CborParserGenerator(
    private val codegenContext: CodegenContext,
@@ -75,6 +95,16 @@ class CborParserGenerator(
    private val shouldWrapBuilderMemberSetterInputWithOption: (MemberShape) -> Boolean = { _ -> true },
    private val customizations: List<CborParserCustomization> = emptyList(),
) : StructuredDataParserGenerator {
    /**
     * Context class that encapsulates the information needed to discriminate union variants during deserialization.
     * @property decoderSymbol The symbol representing the decoder type.
     * @property variantDiscriminatorExpression The method call expression to determine the union variant.
     */
    data class UnionVariantDiscriminatorContext(
        val decoderSymbol: Symbol,
        val variantDiscriminatorExpression: Writable,
    )

    private val model = codegenContext.model
    private val symbolProvider = codegenContext.symbolProvider
    private val runtimeConfig = codegenContext.runtimeConfig
@@ -298,16 +328,26 @@ class CborParserGenerator(
    private fun unionPairParserFnWritable(shape: UnionShape) =
        writable {
            val returnSymbolToParse = returnSymbolToParse(shape)
            // Get actual decoder type to use and the discriminating function to call to extract
            // the variant of the union that has been encoded in the data.
            val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()")

            rustBlockTemplate(
                """
                fn pair(
                    decoder: &mut #{Decoder}
                    decoder: &mut #{DecoderSymbol}
                ) -> #{Result}<#{UnionSymbol}, #{Error}>
                """,
                *codegenScope,
                "DecoderSymbol" to discriminatorContext.decoderSymbol,
                "UnionSymbol" to returnSymbolToParse.symbol,
            ) {
                withBlock("Ok(match decoder.str()?.as_ref() {", "})") {
                rustTemplate(
                    """
                    Ok(match #{VariableDiscriminatingExpression} {
                    """,
                    "VariableDiscriminatingExpression" to discriminatorContext.variantDiscriminatorExpression,
                ).run {
                    for (member in shape.members()) {
                        val variantName = symbolProvider.toMemberName(member)

@@ -349,6 +389,21 @@ class CborParserGenerator(
                            )
                    }
                }
                rust("})")
            }
        }

    private fun getUnionDiscriminatorContext(
        decoderType: String,
        callMethod: String,
    ): UnionVariantDiscriminatorContext {
        val defaultUnionPairContext =
            UnionVariantDiscriminatorContext(
                smithyCbor.resolve(decoderType).toSymbol(),
                writable { rustTemplate(callMethod) },
            )
        return customizations.fold(defaultUnionPairContext) { context, customization ->
            customization.getUnionVariantDiscriminator(context)
        }
    }

@@ -677,12 +732,22 @@ class CborParserGenerator(

    private fun RustWriter.deserializeUnion(shape: UnionShape) {
        val returnSymbolToParse = returnSymbolToParse(shape)
        val beforeDecoderMapCustomization =
            customizations.map { customization ->
                customization.section(
                    CborParserSection.UnionParserBeforeDecodingMap(
                        shape,
                    ),
                )
            }.join("")

        val parser =
            protocolFunctions.deserializeFn(shape) { fnName ->
                rustTemplate(
                    """
                    pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> {
                        #{UnionPairParserFnWritable}
                        #{BeforeDecoderMapCustomization:W}

                        match decoder.map()? {
                            None => {
@@ -707,6 +772,7 @@ class CborParserGenerator(
                    """,
                    "UnionSymbol" to returnSymbolToParse.symbol,
                    "UnionPairParserFnWritable" to unionPairParserFnWritable(shape),
                    "BeforeDecoderMapCustomization" to beforeDecoderMapCustomization,
                    *codegenScope,
                )
            }
+13 −0
Original line number Diff line number Diff line
@@ -70,6 +70,13 @@ sealed class JsonParserSection(name: String) : Section(name) {

    data class AfterDocumentDeserializedMember(val shape: MemberShape) :
        JsonParserSection("AfterDocumentDeserializedMember")

    /**
     * Represents a customization point at the beginning of union deserialization, before any token
     * processing occurs.
     */
    data class BeforeUnionDeserialize(val shape: UnionShape) :
        JsonParserSection("BeforeUnionDeserialize")
}

/**
@@ -548,6 +555,12 @@ class JsonParserGenerator(
                    *codegenScope,
                    "Shape" to returnSymbolToParse.symbol,
                ) {
                    // Apply any custom union deserialization logic before processing tokens.
                    // This allows for customization of how union variants are handled,
                    // particularly their discrimination mechanism.
                    for (customization in customizations) {
                        customization.section(JsonParserSection.BeforeUnionDeserialize(shape))(this)
                    }
                    rust("let mut variant = None;")
                    val checkValueSet = !shape.members().all { it.isTargetUnit() } && !codegenTarget.renderUnknownVariant()
                    rustBlock("match tokens.next().transpose()?") {
+188 −17

File changed.

Preview size limit exceeded, changes collapsed.

Loading