Commit d026ad1f authored by Fahad Zubair's avatar Fahad Zubair
Browse files

Use constraints.smithy with CBor

parent b38ccb96
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -66,6 +66,10 @@ sealed class CborSerializerSection(name: String) : Section(name) {
    /** Manipulate the serializer context for a map prior to it being serialized. **/
    data class BeforeIteratingOverMapOrCollection(val shape: Shape, val context: CborSerializerGenerator.Context<Shape>) :
        CborSerializerSection("BeforeIteratingOverMapOrCollection")

    /** Manipulate the serializer context for a non-null member prior to it being serialized. **/
    data class BeforeSerializingNonNullMember(val shape: Shape, val context: CborSerializerGenerator.MemberContext) :
        CborSerializerSection("BeforeSerializingNonNullMember")
}

/**
@@ -311,6 +315,7 @@ class CborSerializerGenerator(
            safeName().also { local ->
                rustBlock("if let Some($local) = ${context.valueExpression.asRef()}") {
                    context.valueExpression = ValueExpression.Reference(local)
                    resolveValueExpressionForConstrainedType(targetShape, context)
                    serializeMemberValue(context, targetShape)
                }
                if (context.writeNulls) {
@@ -320,6 +325,7 @@ class CborSerializerGenerator(
                }
            }
        } else {
            resolveValueExpressionForConstrainedType(targetShape, context)
            with(serializerUtil) {
                ignoreDefaultsForNumbersAndBools(context.shape, context.valueExpression) {
                    serializeMemberValue(context, targetShape)
@@ -328,6 +334,20 @@ class CborSerializerGenerator(
        }
    }

    private fun RustWriter.resolveValueExpressionForConstrainedType(
        targetShape: Shape,
        context: MemberContext,
    ) {
        for (customization in customizations) {
            customization.section(
                CborSerializerSection.BeforeSerializingNonNullMember(
                    targetShape,
                    context,
                ),
            )(this)
        }
    }

    private fun RustWriter.serializeMemberValue(
        context: MemberContext,
        target: Shape,
@@ -362,7 +382,7 @@ class CborSerializerGenerator(
                    rust("$encoder;") // Encode the member key.
                }
                when (target) {
                    is StructureShape -> serializeStructure(StructContext(value.name, target))
                    is StructureShape -> serializeStructure(StructContext(value.asRef(), target))
                    is CollectionShape -> serializeCollection(Context(value, target))
                    is MapShape -> serializeMap(Context(value, target))
                    is UnionShape -> serializeUnion(Context(value, target))
+3 −1
Original line number Diff line number Diff line
@@ -90,6 +90,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachVali
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ServerProtocolBasedTransformationFactory
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
import java.util.logging.Logger

@@ -133,7 +134,8 @@ open class ServerCodegenVisitor(
                .protocolFor(context.model, service)
        this.protocolGeneratorFactory = protocolGeneratorFactory

        model = codegenDecorator.transformModel(service, baseModel, settings)
        val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service)
        model = codegenDecorator.transformModel(service, protocolTransformedModel, settings)

        val serverSymbolProviders =
            ServerSymbolProviders.from(
+36 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.server.smithy.customizations

import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.LongShape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerCustomization
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.CborSerializerSection
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.workingWithPublicConstrainedWrapperTupleType

class BeforeSerializingMemberCborCustomization(private val codegenContext: ServerCodegenContext) : CborSerializerCustomization() {
    override fun section(section: CborSerializerSection): Writable =
        when (section) {
            is CborSerializerSection.BeforeSerializingNonNullMember ->
                writable {
                    if (workingWithPublicConstrainedWrapperTupleType(
                            section.shape,
                            codegenContext.model,
                            codegenContext.settings.codegenConfig.publicConstrainedTypes,
                        )
                    ) {
                        if (section.shape is IntegerShape || section.shape is ShortShape || section.shape is LongShape || section.shape is ByteShape || section.shape is BlobShape) {
                            section.context.valueExpression =
                                ValueExpression.Reference("&${section.context.valueExpression.name}.0")
                        }
                    }
                }

            else -> emptySection
        }
}
+2 −0
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddTypeFieldToServerErrorsCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeEncodingMapOrCollectionCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.customizations.BeforeSerializingMemberCborCustomization
import software.amazon.smithy.rust.codegen.server.smithy.generators.http.RestRequestSpecGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
@@ -349,6 +350,7 @@ class ServerRpcV2CborProtocol(
            listOf(
                BeforeEncodingMapOrCollectionCborCustomization(serverCodegenContext),
                AddTypeFieldToServerErrorsCborCustomization(),
                BeforeSerializingMemberCborCustomization(serverCodegenContext),
            ),
        )
    }
+82 −0
Original line number Diff line number Diff line
package software.amazon.smithy.rust.codegen.server.smithy.transformers

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.utils.SmithyBuilder
import software.amazon.smithy.utils.ToSmithyBuilder

/**
 * Each protocol may not support all of the features that Smithy allows. For instance, most
 * RPC protocols do not support HTTP bindings. `ServerProtocolBasedTransformationFactory` is a factory
 * object that transforms the model and removes specific traits based on the protocol being instantiated.
 */
object ServerProtocolBasedTransformationFactory {
    fun createTransformer(protocolShapeId: ShapeId): Transformer =
        when (protocolShapeId) {
            Rpcv2CborTrait.ID -> Rpcv2Transformer()
            else -> IdentityTransformer()
        }

    interface Transformer {
        fun transform(
            model: Model,
            service: ServiceShape,
        ): Model
    }

    fun <T : Shape, B> T.removeTraitIfPresent(
        traitId: ShapeId,
    ): T
        where T : ToSmithyBuilder<T>,
              B : AbstractShapeBuilder<B, T>,
              B : SmithyBuilder<T> {
        return if (this.hasTrait(traitId)) {
            @Suppress("UNCHECKED_CAST")
            (this.toBuilder() as B).removeTrait(traitId).build()
        } else {
            this
        }
    }

    class Rpcv2Transformer() : Transformer {
        override fun transform(
            model: Model,
            service: ServiceShape,
        ): Model {
            val transformedModel =
                ModelTransformer.create().mapShapes(model) { shape ->
                    when (shape) {
                        is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID)
                        is MemberShape -> {
                            shape
                                .removeTraitIfPresent(HttpLabelTrait.ID)
                                .removeTraitIfPresent(HttpPayloadTrait.ID)
                        }

                        else -> shape
                    }
                }

            return transformedModel
        }
    }

    class IdentityTransformer() : Transformer {
        override fun transform(
            model: Model,
            service: ServiceShape,
        ): Model {
            return model
        }
    }
}
Loading