Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +21 −1 Original line number Diff line number Diff line Loading @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) { /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) : CborSerializerSection("BeforeIteratingOverMapOrCollection") /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) : CborSerializerSection("BeforeSerializingNonNullMember") } /** Loading Loading @@ -311,6 +315,7 @@ class CborSerializerGenerator( safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { context.valueExpression = ValueExpression.Reference(local) resolveValueExpressionForConstrainedType(targetShape, context) serializeMemberValue(context, targetShape) } if (context.writeNulls) { Loading @@ -320,6 +325,7 @@ class CborSerializerGenerator( } } } else { resolveValueExpressionForConstrainedType(targetShape, context) with(serializerUtil) { ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) Loading @@ -328,6 +334,20 @@ class CborSerializerGenerator( } } private fun RustWriter.resolveValueExpressionForConstrainedType( targetShape: Shape, context: MemberContext, ) { for (customization in customizations) { customization.section( CborSerializerSection.BeforeSerializingNonNullMember( targetShape, context, ), )(this) } } private fun RustWriter.serializeMemberValue( context: MemberContext, target: Shape, Loading Loading @@ -362,7 +382,7 @@ class CborSerializerGenerator( rust("$encoder;") // Encode the member key. } when (target) { is StructureShape -> serializeStructure(StructContext(value.name, target)) is StructureShape -> serializeStructure(StructContext(value.asRef(), target)) is CollectionShape -> serializeCollection(Context(value, target)) is MapShape -> serializeMap(Context(value, target)) is UnionShape -> serializeUnion(Context(value, target)) Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +3 −1 Original line number Diff line number Diff line Loading @@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger Loading Loading @@ -133,7 +134,8 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory model = codegenDecorator.transformModel(service, baseModel, settings) val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt 0 → 100644 +36 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { override fun section(section: CborSerializerSection): Writable = when (section) { is CborSerializerSection.BeforeSerializingNonNullMember -> writable { if (workingWithPublicConstrainedWrapperTupleType( section.shape, codegenContext.model, codegenContext.settings.codegenConfig.publicConstrainedTypes, ) ) { if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { section.context.valueExpression = ValueExpression.Reference("&${section.context.valueExpression.name}.0") } } } else -> emptySection } } codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +2 −0 Original line number Diff line number Diff line Loading @@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator Loading Loading @@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol( listOf( BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), AddTypeFieldToServerErrorsCborCustomization(), BeforeSerializingMemberCborCustomization(serverCodegenContext), ), ) } Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt 0 → 100644 +82 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.AbstractShapeBuilder import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.HttpLabelTrait import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.utils.SmithyBuilder import software.amazon.smithy.utils.ToSmithyBuilder /** * Each protocol may not support all of the features that Smithy allows. For instance, most * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory * object that transforms the model and removes specific traits based on the protocol being instantiated. */ object ServerProtocolBasedTransformationFactory { fun createTransformer(protocolShapeId: ShapeId): Transformer = when (protocolShapeId) { Rpcv2CborTrait.ID -> Rpcv2Transformer() else -> IdentityTransformer() } interface Transformer { fun transform( model: Model, service: ServiceShape, ): Model } fun <T : Shape, B> T.removeTraitIfPresent( traitId: ShapeId, ): T where T : ToSmithyBuilder<T>, B : AbstractShapeBuilder<B, T>, B : SmithyBuilder<T> { return if (this.hasTrait(traitId)) { @Suppress("UNCHECKED_CAST") (this.toBuilder() as B).removeTrait(traitId).build() } else { this } } class Rpcv2Transformer() : Transformer { override fun transform( model: Model, service: ServiceShape, ): Model { val transformedModel = ModelTransformer.create().mapShapes(model) { shape -> when (shape) { is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) is MemberShape -> { shape .removeTraitIfPresent(HttpLabelTrait.ID) .removeTraitIfPresent(HttpPayloadTrait.ID) } else -> shape } } return transformedModel } } class IdentityTransformer() : Transformer { override fun transform( model: Model, service: ServiceShape, ): Model { return model } } } Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +21 −1 Original line number Diff line number Diff line Loading @@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) { /** Manipulate the serializer context for a map prior to it being serialized. **/ data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) : CborSerializerSection("BeforeIteratingOverMapOrCollection") /** Manipulate the serializer context for a non-null member prior to it being serialized. **/ data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) : CborSerializerSection("BeforeSerializingNonNullMember") } /** Loading Loading @@ -311,6 +315,7 @@ class CborSerializerGenerator( safeName().also { local -> rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") { context.valueExpression = ValueExpression.Reference(local) resolveValueExpressionForConstrainedType(targetShape, context) serializeMemberValue(context, targetShape) } if (context.writeNulls) { Loading @@ -320,6 +325,7 @@ class CborSerializerGenerator( } } } else { resolveValueExpressionForConstrainedType(targetShape, context) with(serializerUtil) { ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) { serializeMemberValue(context, targetShape) Loading @@ -328,6 +334,20 @@ class CborSerializerGenerator( } } private fun RustWriter.resolveValueExpressionForConstrainedType( targetShape: Shape, context: MemberContext, ) { for (customization in customizations) { customization.section( CborSerializerSection.BeforeSerializingNonNullMember( targetShape, context, ), )(this) } } private fun RustWriter.serializeMemberValue( context: MemberContext, target: Shape, Loading Loading @@ -362,7 +382,7 @@ class CborSerializerGenerator( rust("$encoder;") // Encode the member key. } when (target) { is StructureShape -> serializeStructure(StructContext(value.name, target)) is StructureShape -> serializeStructure(StructContext(value.asRef(), target)) is CollectionShape -> serializeCollection(Context(value, target)) is MapShape -> serializeMap(Context(value, target)) is UnionShape -> serializeUnion(Context(value, target)) Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +3 −1 Original line number Diff line number Diff line Loading @@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger Loading Loading @@ -133,7 +134,8 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory model = codegenDecorator.transformModel(service, baseModel, settings) val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/BeforeSerializingMemberCborCustomization.kt 0 → 100644 +36 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() { override fun section(section: CborSerializerSection): Writable = when (section) { is CborSerializerSection.BeforeSerializingNonNullMember -> writable { if (workingWithPublicConstrainedWrapperTupleType( section.shape, codegenContext.model, codegenContext.settings.codegenConfig.publicConstrainedTypes, ) ) { if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) { section.context.valueExpression = ValueExpression.Reference("&${section.context.valueExpression.name}.0") } } } else -> emptySection } }
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +2 −0 Original line number Diff line number Diff line Loading @@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator Loading Loading @@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol( listOf( BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext), AddTypeFieldToServerErrorsCborCustomization(), BeforeSerializingMemberCborCustomization(serverCodegenContext), ), ) } Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ServerProtocolBasedTransformationFactory.kt 0 → 100644 +82 −0 Original line number Diff line number Diff line package software.amazon.smithy.rust.codegen.server.smithy.transformers import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.AbstractShapeBuilder import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.HttpLabelTrait import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait import software.amazon.smithy.utils.SmithyBuilder import software.amazon.smithy.utils.ToSmithyBuilder /** * Each protocol may not support all of the features that Smithy allows. For instance, most * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory * object that transforms the model and removes specific traits based on the protocol being instantiated. */ object ServerProtocolBasedTransformationFactory { fun createTransformer(protocolShapeId: ShapeId): Transformer = when (protocolShapeId) { Rpcv2CborTrait.ID -> Rpcv2Transformer() else -> IdentityTransformer() } interface Transformer { fun transform( model: Model, service: ServiceShape, ): Model } fun <T : Shape, B> T.removeTraitIfPresent( traitId: ShapeId, ): T where T : ToSmithyBuilder<T>, B : AbstractShapeBuilder<B, T>, B : SmithyBuilder<T> { return if (this.hasTrait(traitId)) { @Suppress("UNCHECKED_CAST") (this.toBuilder() as B).removeTrait(traitId).build() } else { this } } class Rpcv2Transformer() : Transformer { override fun transform( model: Model, service: ServiceShape, ): Model { val transformedModel = ModelTransformer.create().mapShapes(model) { shape -> when (shape) { is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID) is MemberShape -> { shape .removeTraitIfPresent(HttpLabelTrait.ID) .removeTraitIfPresent(HttpPayloadTrait.ID) } else -> shape } } return transformedModel } } class IdentityTransformer() : Transformer { override fun transform( model: Model, service: ServiceShape, ): Model { return model } } }