diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt index 7b9ad233667e61e428a84c4f1becca939a472768..b2af3df6f335fd053437cee08547244b3d1070b4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/Writable.kt @@ -45,15 +45,30 @@ operator fun Writable.plus(other: Writable): Writable = /** * Helper allowing a `Iterable` to be joined together using a `String` separator. + * @param separator The string to use as a separator between elements + * @param prefix An optional string to prepend to the entire joined sequence (defaults to null) + * @return A Writable containing the optionally prefixed, joined elements */ -fun Iterable.join(separator: String) = join(writable(separator)) +fun Iterable.join( + separator: String, + prefix: String? = null, +) = join(writable(separator), prefix?.let { writable(it) }) /** * Helper allowing a `Iterable` to be joined together using a `Writable` separator. + * @param separator The Writable to use as a separator between elements + * @param prefix An optional Writable to prepend to the entire joined sequence (defaults to null) + * @return A Writable containing the optionally prefixed, joined elements */ -fun Iterable.join(separator: Writable): Writable { +fun Iterable.join( + separator: Writable, + prefix: Writable? = null, +): Writable { val iter = this.iterator() return writable { + if (iter.hasNext() && prefix != null) { + prefix() + } iter.forEach { value -> value() if (iter.hasNext()) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt index c036e5f9aee2b300e503d641977df47686ce42af..94ff37909c7c7becf6759a27c7d97e5689f3ca6b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RpcV2Cbor.kt @@ -19,8 +19,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.CborParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer @@ -92,7 +94,11 @@ class RpcV2CborHttpBindingResolver( ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/cbor") } -open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { +open class RpcV2Cbor( + val codegenContext: CodegenContext, + private val serializeCustomization: List = listOf(), + private val parserCustomization: List = listOf(), +) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig override val httpBindingResolver: HttpBindingResolver = @@ -134,10 +140,11 @@ open class RpcV2Cbor(val codegenContext: CodegenContext) : Protocol { ) } }, + customizations = parserCustomization, ) override fun structuredDataSerializer(): StructuredDataSerializerGenerator = - CborSerializerGenerator(codegenContext, httpBindingResolver) + CborSerializerGenerator(codegenContext, httpBindingResolver, customizations = serializeCustomization) override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = RuntimeType.cborErrors(runtimeConfig).resolve("parse_error_metadata") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt index 6830d7a1f7468190001b0af7fa62666834701cd5..11fb43191b5eae6bef96ad8d3b778171be737487 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt @@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -57,10 +58,29 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape /** Class describing a CBOR parser section that can be used in a customization. */ sealed class CborParserSection(name: String) : Section(name) { data class BeforeBoxingDeserializedMember(val shape: MemberShape) : CborParserSection("BeforeBoxingDeserializedMember") + + /** + * Represents a customization point in union deserialization that occurs before decoding the map structure. + * This allows for custom handling of union variants before the standard map decoding logic is applied. + * @property shape The union shape being deserialized. + */ + data class UnionParserBeforeDecodingMap(val shape: UnionShape) : CborParserSection("UnionParserBeforeDecodingMap") } -/** Customization for the CBOR parser. */ -typealias CborParserCustomization = NamedCustomization +/** + * Customization class for CBOR parser generation that allows modification of union type deserialization behavior. + * Previously, union variant discrimination was hardcoded to use `decoder.str()`. This has been made more flexible + * to support different decoder implementations and discrimination methods. + */ +abstract class CborParserCustomization : NamedCustomization() { + /** + * Allows customization of how union variants are discriminated during deserialization. + * @param defaultContext The default discrimination context containing decoder symbol and discriminator method. + * @return UnionVariantDiscriminatorContext that defines how to discriminate union variants. + */ + open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) = + defaultContext +} class CborParserGenerator( private val codegenContext: CodegenContext, @@ -75,6 +95,16 @@ class CborParserGenerator( private val shouldWrapBuilderMemberSetterInputWithOption: (MemberShape) -> Boolean = { _ -> true }, private val customizations: List = emptyList(), ) : StructuredDataParserGenerator { + /** + * Context class that encapsulates the information needed to discriminate union variants during deserialization. + * @property decoderSymbol The symbol representing the decoder type. + * @property variantDiscriminatorExpression The method call expression to determine the union variant. + */ + data class UnionVariantDiscriminatorContext( + val decoderSymbol: Symbol, + val variantDiscriminatorExpression: Writable, + ) + private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig @@ -298,16 +328,26 @@ class CborParserGenerator( private fun unionPairParserFnWritable(shape: UnionShape) = writable { val returnSymbolToParse = returnSymbolToParse(shape) + // Get actual decoder type to use and the discriminating function to call to extract + // the variant of the union that has been encoded in the data. + val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()") + rustBlockTemplate( """ fn pair( - decoder: &mut #{Decoder} + decoder: &mut #{DecoderSymbol} ) -> #{Result}<#{UnionSymbol}, #{Error}> """, *codegenScope, + "DecoderSymbol" to discriminatorContext.decoderSymbol, "UnionSymbol" to returnSymbolToParse.symbol, ) { - withBlock("Ok(match decoder.str()?.as_ref() {", "})") { + rustTemplate( + """ + Ok(match #{VariableDiscriminatingExpression} { + """, + "VariableDiscriminatingExpression" to discriminatorContext.variantDiscriminatorExpression, + ).run { for (member in shape.members()) { val variantName = symbolProvider.toMemberName(member) @@ -349,9 +389,24 @@ class CborParserGenerator( ) } } + rust("})") } } + private fun getUnionDiscriminatorContext( + decoderType: String, + callMethod: String, + ): UnionVariantDiscriminatorContext { + val defaultUnionPairContext = + UnionVariantDiscriminatorContext( + smithyCbor.resolve(decoderType).toSymbol(), + writable { rustTemplate(callMethod) }, + ) + return customizations.fold(defaultUnionPairContext) { context, customization -> + customization.getUnionVariantDiscriminator(context) + } + } + enum class CollectionKind { Map, List, @@ -677,12 +732,22 @@ class CborParserGenerator( private fun RustWriter.deserializeUnion(shape: UnionShape) { val returnSymbolToParse = returnSymbolToParse(shape) + val beforeDecoderMapCustomization = + customizations.map { customization -> + customization.section( + CborParserSection.UnionParserBeforeDecodingMap( + shape, + ), + ) + }.join("") + val parser = protocolFunctions.deserializeFn(shape) { fnName -> rustTemplate( """ pub(crate) fn $fnName(decoder: &mut #{Decoder}) -> #{Result}<#{UnionSymbol}, #{Error}> { #{UnionPairParserFnWritable} + #{BeforeDecoderMapCustomization:W} match decoder.map()? { None => { @@ -707,6 +772,7 @@ class CborParserGenerator( """, "UnionSymbol" to returnSymbolToParse.symbol, "UnionPairParserFnWritable" to unionPairParserFnWritable(shape), + "BeforeDecoderMapCustomization" to beforeDecoderMapCustomization, *codegenScope, ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index cf0676ebd0dfc8ca9a4d500c5792d4cb4780355b..8b8245287bb8050c6d0e808c347e23ff6cd95c22 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -70,6 +70,13 @@ sealed class JsonParserSection(name: String) : Section(name) { data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember") + + /** + * Represents a customization point at the beginning of union deserialization, before any token + * processing occurs. + */ + data class BeforeUnionDeserialize(val shape: UnionShape) : + JsonParserSection("BeforeUnionDeserialize") } /** @@ -548,6 +555,12 @@ class JsonParserGenerator( *codegenScope, "Shape" to returnSymbolToParse.symbol, ) { + // Apply any custom union deserialization logic before processing tokens. + // This allows for customization of how union variants are handled, + // particularly their discrimination mechanism. + for (customization in customizations) { + customization.section(JsonParserSection.BeforeUnionDeserialize(shape))(this) + } rust("let mut variant = None;") val checkValueSet = !shape.members().all { it.isTargetUnit() } && !codegenTarget.renderUnknownVariant() rustBlock("match tokens.next().transpose()?") { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt index 1830fb962f5e5c08dd0a9ecf63a5be04a3a199a9..759bd2fa9e71956bb3aad1e00cefc9bf47f3aa28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt @@ -25,10 +25,14 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.isNotEmpty +import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -42,6 +46,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerGenerator.Context import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq @@ -59,8 +64,9 @@ sealed class CborSerializerSection(name: String) : Section(name) { * to record the error type in the case of an error structure. */ data class BeforeSerializingStructureMembers( - val structureShape: StructureShape, + val structContext: CborSerializerGenerator.StructContext, val encoderBindingName: String, + val codegenContext: CodegenContext, ) : CborSerializerSection("BeforeSerializingStructureMembers") /** Manipulate the serializer context for a map prior to it being serialized. **/ @@ -70,6 +76,38 @@ sealed class CborSerializerSection(name: String) : Section(name) { /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) : CborSerializerSection("BeforeSerializingNonNullMember") + + /** + * Allows specification of additional parameters in the function signature of the serializer. + * This customization point enables extending the serializer's interface with supplementary parameters + * needed for specialized serialization behaviors. + */ + data class AdditionalSerializingParameters(val structContext: CborSerializerGenerator.StructContext, val codegenContext: CodegenContext) : + CborSerializerSection("AdditionalSerializingParameters") + + /** + * Provides a way to specify additional arguments that should be passed when invoking the serializer. + * This customization point allows for passing through context-specific information needed during + * the serialization process. + */ + data class AdditionalSerializingArguments(val structContext: CborSerializerGenerator.StructContext, val codegenContext: CodegenContext) : + CborSerializerSection("AdditionalSerializingArguments") + + /** + * Customizes how a union variant's shape ID is encoded in the CBOR format. + * This section allows for specialized handling of union variant identification + * during serialization. + */ + data class CustomizeUnionMemberKeyEncode(val context: CborSerializerGenerator.MemberContext, val encoderBindingName: String, val codegenContext: CodegenContext) : + CborSerializerSection("CustomizeUnionMemberKeyEncode") + + /** + * Allows customization of the CBOR map length calculation for union types. + * This section provides control over how the size of the encoded union + * representation is determined, which may vary based on the serialization requirements. + */ + data class CustomizeUnionEncoderMapLength(val context: Context, val codegenContext: CodegenContext) : + CborSerializerSection("CustomizeUnionEncoderMapLength") } /** @@ -78,7 +116,7 @@ sealed class CborSerializerSection(name: String) : Section(name) { typealias CborSerializerCustomization = NamedCustomization class CborSerializerGenerator( - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, private val httpBindingResolver: HttpBindingResolver, private val customizations: List = listOf(), ) : StructuredDataSerializerGenerator { @@ -136,9 +174,10 @@ class CborSerializerGenerator( fun unionMember( variantReference: String, member: MemberShape, + encodeKeyExpression: String = encodeKeyExpression(member.memberName), ): MemberContext = MemberContext( - encodeKeyExpression(member.memberName), + encodeKeyExpression, ValueExpression.Reference(variantReference), member, ) @@ -152,6 +191,7 @@ class CborSerializerGenerator( /** Name of the variable that holds the struct */ val localName: String, val shape: StructureShape, + val memberContext: MemberContext? = null, ) private val model = codegenContext.model @@ -281,25 +321,55 @@ class CborSerializerGenerator( return serverSerializer(errorShape, includedMembers, error = true) } + /** + * Retrieves customized parameters and arguments for struct serialization based on the provided context and section type. + * + * @param structContext The context containing information about the structure being serialized + * @param sectionType A function that takes a StructContext and CodegenContext and returns a CborSerializerSection. + * This determines which section of the serializer to customize (e.g., AdditionalSerializingArguments) + * @return A Writable containing the joined customized parameters/arguments as a comma-separated string with a leading comma, + * or an empty Writable if no customizations are present + */ + private fun getCustomizedParamsAndArgsForStructSerializer( + structContext: StructContext, + sectionType: (StructContext, CodegenContext) -> CborSerializerSection, + ) = customizations + .map { it.section(sectionType(structContext, codegenContext)) } + .filter { it.isNotEmpty() } // Remove any empty customizations. + .takeIf { it.isNotEmpty() } // Proceed only if there are remaining customizations. + ?.join(", ", prefix = ", ") // Join with commas and add leading comma. + ?: writable {} // Return empty writable if no customizations exist. + private fun RustWriter.serializeStructure( context: StructContext, includedMembers: List? = null, ) { if (context.shape.isUnit()) { - rust( - """ - encoder.begin_map(); - encoder.end(); - """, - ) + rust("encoder.begin_map();") + for (customization in customizations) { + customization.section( + CborSerializerSection.BeforeSerializingStructureMembers( + context, + "encoder", + codegenContext, + ), + )(this) + } + rust("encoder.end();") return } val structureSerializer = protocolFunctions.serializeFn(context.shape) { fnName -> + val paramsWritable = + getCustomizedParamsAndArgsForStructSerializer( + context, + CborSerializerSection::AdditionalSerializingParameters, + ) rustBlockTemplate( - "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol}) -> #{Result}<(), #{Error}>", + "pub fn $fnName(encoder: &mut #{Encoder}, ##[allow(unused)] input: &#{StructureSymbol} #{Params}) -> #{Result}<(), #{Error}>", "StructureSymbol" to symbolProvider.toSymbol(context.shape), + "Params" to paramsWritable, *codegenScope, ) { // TODO(https://github.com/smithy-lang/smithy-rs/issues/3745) If all members are non-`Option`-al, @@ -308,8 +378,9 @@ class CborSerializerGenerator( for (customization in customizations) { customization.section( CborSerializerSection.BeforeSerializingStructureMembers( - context.shape, + context, "encoder", + codegenContext, ), )(this) } @@ -323,7 +394,17 @@ class CborSerializerGenerator( rust("Ok(())") } } - rust("#T(encoder, ${context.localName})?;", structureSerializer) + + val argsWritable = + getCustomizedParamsAndArgsForStructSerializer( + context, + CborSerializerSection::AdditionalSerializingArguments, + ) + rustTemplate( + "#{SerializingFunction}(encoder, ${context.localName} #{Args})?;", + "SerializingFunction" to structureSerializer, + "Args" to argsWritable, + ) } private fun RustWriter.serializeMember(context: MemberContext) { @@ -396,10 +477,17 @@ class CborSerializerGenerator( else -> { // This condition is equivalent to `containerShape !is CollectionShape`. if (containerShape is StructureShape || containerShape is UnionShape || containerShape is MapShape) { - rust("$encoder;") // Encode the member key. + val customizedMemberKeyWritable = + validateAndGetUniqueUnionVariantKeyCustomizedEncodingOrEmpty(context, encoder) + + if (customizedMemberKeyWritable.isNotEmpty()) { + customizedMemberKeyWritable(this) + } else { + rust("$encoder;") // Encode the member key. + } } when (target) { - is StructureShape -> serializeStructure(StructContext(value.asRef(), target)) + is StructureShape -> serializeStructure(StructContext(value.asRef(), target, context)) is CollectionShape -> serializeCollection(Context(value, target)) is MapShape -> serializeMap(Context(value, target)) is UnionShape -> serializeUnion(Context(value, target)) @@ -442,9 +530,21 @@ class CborSerializerGenerator( "UnionSymbol" to unionSymbol, *codegenScope, ) { - // A union is serialized identically as a `structure` shape, but only a single member can be set to a - // non-null value. - rust("encoder.map(1);") + // Processes customizations for encoding union variants. This determines how the variant + // type information is serialized in the generated code. + val customUnionEncoderLength = + validateAndGetUniqueUnionVariantEncoderLengthCustomizedEncodingOrEmpty(context) + + // Apply any custom variant encoding logic if customizations exist + // Otherwise fall back to default union variant serialization + if (customUnionEncoderLength.isNotEmpty()) { + customUnionEncoderLength(this) + } else { + // A union is serialized identically as a `structure` shape, but only a single member can be set to a + // non-null value. + rust("encoder.map(1);") + } + rustBlock("match input") { for (member in context.shape.members()) { val variantName = @@ -470,4 +570,75 @@ class CborSerializerGenerator( } rust("#T(encoder, ${context.valueExpression.asRef()})?;", unionSerializer) } + + /** + * Process and validate customizations for a CborSerializerSection. + * + * @param section The CborSerializerSection to validate + * @param customizationName A descriptive name for the customization type (used in error messages) + * @return A [Writable] containing the single valid customization, or an empty writable if none exist + * @throws IllegalArgumentException if multiple customizations are found + */ + private fun validateAndGetUniqueCustomizationOrEmpty( + section: CborSerializerSection, + customizationName: String, + ): Writable = + customizations.map { customization -> + customization.section(section) + } + .filter { it.isNotEmpty() } + .also { filteredCustomizations -> + if (filteredCustomizations.size > 1) { + throw IllegalArgumentException( + "Found ${filteredCustomizations.size} $customizationName customizations, but only one is allowed.", + ) + } + } + .firstOrNull() ?: writable {} + + /** + * Process customizations for union variant encoding, ensuring only one customization exists. + * + * This function processes all customizations to find those that modify how a union variant's + * map length is encoded. It enforces that at most one such customization exists to prevent + * conflicting encoding strategies. + * + * @throws IllegalArgumentException if multiple customizations are found. + * @return A [Writable] containing the custom encoding logic, or an empty writable if no + * customizations exist, in which case the default encoding should be used. + */ + private fun validateAndGetUniqueUnionVariantEncoderLengthCustomizedEncodingOrEmpty( + context: Context, + ): Writable = + validateAndGetUniqueCustomizationOrEmpty( + CborSerializerSection.CustomizeUnionEncoderMapLength( + context, + codegenContext, + ), + "union encoder map length", + ) + + /** + * Validates and retrieves a single customization for encoding a union variant's key. + * + * This function processes all customizations to find those that modify how a union variant's + * key is encoded. It enforces that at most one such customization exists to prevent + * conflicting encoding strategies. + * + * @throws IllegalArgumentException if multiple customizations are found. + * @return A [Writable] containing the custom encoding logic, or an empty writable if no + * customizations exist, in which case the default encoding should be used. + */ + private fun validateAndGetUniqueUnionVariantKeyCustomizedEncodingOrEmpty( + context: MemberContext, + encoder: String, + ): Writable = + validateAndGetUniqueCustomizationOrEmpty( + CborSerializerSection.CustomizeUnionMemberKeyEncode( + context, + encoder, + codegenContext, + ), + "union variant key", + ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt index 69ec11fdd2830d9bdeb286b3928174759ed09808..f019be99d010a484f1f3752ba0eaf153d7057fa4 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGenerator.kt @@ -76,6 +76,10 @@ sealed class JsonSerializerSection(name: String) : Section(name) { /** Mutate the output object prior to finalization. */ data class OutputStruct(val structureShape: StructureShape, val jsonObject: String) : JsonSerializerSection("OutputStruct") + + /** Allow customizers to perform pre-serialization operations before handling union variants. */ + data class BeforeSerializeUnion(val shape: UnionShape, val jsonObject: String) : + JsonSerializerSection("BeforeSerializeUnion") } /** @@ -545,6 +549,10 @@ class JsonSerializerGenerator( "Input" to unionSymbol, *codegenScope, ) { + // Allow customizers to perform pre-serialization operations before handling union variants. + customizations.forEach { + it.section(JsonSerializerSection.BeforeSerializeUnion(context.shape, context.writerExpression))(this) + } rustBlock("match input") { for (member in context.shape.members()) { val variantName = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt index 464a52dc463ae61ebd23bef0b70adfa17696752a..9fcca1062a6396670b33ebb2b4911c38a73b69b8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AddTypeFieldToServerErrorsCborCustomization.kt @@ -42,13 +42,13 @@ class AddTypeFieldToServerErrorsCborCustomization : CborSerializerCustomization( override fun section(section: CborSerializerSection): Writable = when (section) { is CborSerializerSection.BeforeSerializingStructureMembers -> - if (section.structureShape.hasTrait()) { + if (section.structContext.shape.hasTrait()) { writable { rust( """ ${section.encoderBindingName} .str("__type") - .str("${escape(section.structureShape.id.toString())}"); + .str("${escape(section.structContext.shape.id.toString())}"); """, ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index ac506aebd41981d20a1363b6142fa58df95a4fb2..43982b9b3e039d3dc07ed0306b64c70086b1584d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -448,5 +448,6 @@ class ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedCborPa rust("let v = v.into();") } } + else -> emptySection } } diff --git a/rust-runtime/aws-smithy-cbor/Cargo.toml b/rust-runtime/aws-smithy-cbor/Cargo.toml index 1225e10ff0143b8cebe47da184baf7560a51a57f..3c06f8a38207f90564340551b4a03fb4db5cebe2 100644 --- a/rust-runtime/aws-smithy-cbor/Cargo.toml +++ b/rust-runtime/aws-smithy-cbor/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "aws-smithy-cbor" -version = "0.60.8" +version = "0.61.0" authors = [ "AWS Rust SDK Team ", "David Pérez ", diff --git a/rust-runtime/aws-smithy-cbor/src/decode.rs b/rust-runtime/aws-smithy-cbor/src/decode.rs index 29ba86bc5d5285156660bd94314d7ce5efafff8a..3c6df7b91a8f389f7b9a28bf87c8f3d669074799 100644 --- a/rust-runtime/aws-smithy-cbor/src/decode.rs +++ b/rust-runtime/aws-smithy-cbor/src/decode.rs @@ -166,6 +166,11 @@ impl<'b> Decoder<'b> { self.decoder.position() } + /// Set the current decode position. + pub fn set_position(&mut self, pos: usize) { + self.decoder.set_position(pos) + } + /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an /// indefinite-length string. An error is returned if the element is neither a definite length nor an