Unverified Commit 0a63d5b8 authored by Fahad Zubair's avatar Fahad Zubair Committed by GitHub
Browse files

Pass UnionShape for union type discrimination (#3984)



UnionShape needs to be passed so that the customization code can detect
which Union type is being generated.

Co-authored-by: default avatarFahad Zubair <fahadzub@amazon.com>
parent 3d801c4f
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -78,8 +78,10 @@ abstract class CborParserCustomization : NamedCustomization<CborParserSection>()
     * @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
     * @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
     */
    open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) =
        defaultContext
    open fun getUnionVariantDiscriminator(
        unionShape: UnionShape,
        defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext,
    ) = defaultContext
}

class CborParserGenerator(
@@ -330,7 +332,7 @@ class CborParserGenerator(
            val returnSymbolToParse = returnSymbolToParse(shape)
            // Get actual decoder type to use and the discriminating function to call to extract
            // the variant of the union that has been encoded in the data.
            val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()")
            val discriminatorContext = getUnionDiscriminatorContext(shape, "Decoder", "decoder.str()?.as_ref()")

            rustBlockTemplate(
                """
@@ -394,6 +396,7 @@ class CborParserGenerator(
        }

    private fun getUnionDiscriminatorContext(
        unionShape: UnionShape,
        decoderType: String,
        callMethod: String,
    ): UnionVariantDiscriminatorContext {
@@ -403,7 +406,7 @@ class CborParserGenerator(
                writable { rustTemplate(callMethod) },
            )
        return customizations.fold(defaultUnionPairContext) { context, customization ->
            customization.getUnionVariantDiscriminator(context)
            customization.getUnionVariantDiscriminator(unionShape, context)
        }
    }