Commit fd8c25cc authored by Fahad Zubair's avatar Fahad Zubair
Browse files

Use CBOR encoded string for marhsalling tests

parent d026ad1f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ dependencies {
    implementation("org.jsoup:jsoup:1.16.2")
    api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
    api("com.moandjiezana.toml:toml4j:0.7.2")
    implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
    implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
    implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
+18 −2
Original line number Diff line number Diff line
@@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
@@ -447,7 +446,24 @@ class CborParserGenerator(
    }

    override fun payloadParser(member: MemberShape): RuntimeType {
        UNREACHABLE("No protocol using CBOR serialization supports payload binding")
        val shape = model.expectShape(member.target)
        val returnSymbol = returnSymbolToParse(shape)
        check(shape is UnionShape || shape is StructureShape) {
            "Payload parser should only be used on structure and union shapes."
        }
        return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
            rustTemplate(
                """
                    pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
                        let decoder = &mut #{Decoder}::new(value);
                        #{DeserializeMember}
                    }
                """,
                "ReturnType" to returnSymbol.symbol,
                "DeserializeMember" to deserializeMember(member),
                *codegenScope,
            )
        }
    }

    override fun operationParser(operationShape: OperationShape): RuntimeType? {
+20 −2
Original line number Diff line number Diff line
@@ -204,9 +204,26 @@ class CborSerializerGenerator(
        }
    }

    // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
    override fun payloadSerializer(member: MemberShape): RuntimeType {
        TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
        val target = model.expectShape(member.target)
        return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
            rustBlockTemplate(
                "pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
                *codegenScope,
                "target" to symbolProvider.toSymbol(target),
            ) {
                rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
                rustBlock("") {
                    rust("let encoder = &mut encoder;")
                    when (target) {
                        is StructureShape -> serializeStructure(StructContext("input", target))
                        is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
                        else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
                    }
                }
                rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
            }
        }
    }

    override fun unsetStructure(structure: StructureShape): RuntimeType =
@@ -223,6 +240,7 @@ class CborSerializerGenerator(
        }

        val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT)

        val inputShape = operationShape.inputShape(model)
        return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName ->
            rustBlockTemplate(
+61 −27
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@

package software.amazon.smithy.rust.codegen.core.testutil

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
@@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
import java.util.Base64

private fun fillInBaseModel(
    protocolName: String,
    namespacedProtocolName: String,
    extraServiceAnnotations: String = "",
): String =
    """
    namespace test

    use smithy.framework#ValidationException
    use aws.protocols#$protocolName
    use $namespacedProtocolName

    union TestUnion {
        Foo: String,
@@ -86,22 +90,24 @@ private fun fillInBaseModel(
    }

    $extraServiceAnnotations
    @$protocolName
    @${namespacedProtocolName.substringAfter("#")}
    service TestService { version: "123", operations: [TestStreamOp] }
    """

object EventStreamTestModels {
    private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
    private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()

    private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
    private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()

    private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
    private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()

    private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()

    private fun awsQuery(): Model =
        fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
        fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

    private fun ec2Query(): Model =
        fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
        fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

    data class TestCase(
        val protocolShapeId: String,
@@ -120,11 +126,21 @@ object EventStreamTestModels {
        override fun toString(): String = protocolShapeId
    }

    val TEST_CASES =
        listOf(
            //
            // restJson1
            //
    private fun base64Encode(input: ByteArray): String {
        val encodedBytes = Base64.getEncoder().encode(input)
        return String(encodedBytes)
    }

    private fun createCBORFromJSON(jsonString: String): ByteArray {
        val jsonMapper = ObjectMapper()
        val cborMapper = ObjectMapper(CBORFactory())
        // Parse JSON string to a generic type.
        val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
        // Convert the parsed data to CBOR.
        return cborMapper.writeValueAsBytes(jsonData)
    }

    private val restJsonTestCase =
        TestCase(
            protocolShapeId = "aws.protocols#restJson1",
            model = restJson1(),
@@ -137,22 +153,40 @@ object EventStreamTestModels {
            validTestUnion = """{"Foo":"hello"}""",
            validSomeError = """{"Message":"some error"}""",
            validUnmodeledError = """{"Message":"unmodeled error"}""",
            ) { RestJson(it) },
        ) { RestJson(it) }

    val TEST_CASES =
        listOf(
            //
            // restJson1
            //
            restJsonTestCase,
            //
            // rpcV2Cbor
            //
            restJsonTestCase.copy(
                protocolShapeId = "smithy.protocols#rpcv2Cbor",
                model = rpcv2Cbor(),
                mediaType = "application/cbor",
                responseContentType = "application/cbor",
                eventStreamMessageContentType = "application/cbor",
                validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)),
                validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
                validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)),
                validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)),
                validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)),
                protocolBuilder = { RpcV2Cbor(it) },
            ),
            //
            // awsJson1_1
            //
            TestCase(
            restJsonTestCase.copy(
                protocolShapeId = "aws.protocols#awsJson1_1",
                model = awsJson11(),
                mediaType = "application/x-amz-json-1.1",
                requestContentType = "application/x-amz-json-1.1",
                responseContentType = "application/x-amz-json-1.1",
                eventStreamMessageContentType = "application/json",
                validTestStruct = """{"someString":"hello","someInt":5}""",
                validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
                validTestUnion = """{"Foo":"hello"}""",
                validSomeError = """{"Message":"some error"}""",
                validUnmodeledError = """{"Message":"unmodeled error"}""",
            ) { AwsJson(it, AwsJsonVersion.Json11) },
            //
            // restXml
+3 −2
Original line number Diff line number Diff line
@@ -134,8 +134,7 @@ open class ServerCodegenVisitor(
                .protocolFor(context.model, service)
        this.protocolGeneratorFactory = protocolGeneratorFactory

        val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service)
        model = codegenDecorator.transformModel(service, protocolTransformedModel, settings)
        model = codegenDecorator.transformModel(service, baseModel, settings)

        val serverSymbolProviders =
            ServerSymbolProviders.from(
@@ -210,6 +209,8 @@ open class ServerCodegenVisitor(
            .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) }
            // Tag aggregate shapes reachable from operation input
            .let(ShapesReachableFromOperationInputTagger::transform)
            // Remove traits that are not supported by the chosen protocol
            .let { ServerProtocolBasedTransformationFactory.transform(it, settings) }
            // Normalize event stream operations
            .let(EventStreamNormalizer::transform)

Loading