Loading codegen-core/build.gradle.kts +1 −0 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ dependencies { implementation("org.jsoup:jsoup:1.16.2") api("software.amazon.smithy:smithy-codegen-core:$smithyVersion") api("com.moandjiezana.toml:toml4j:0.7.2") implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +18 −2 Original line number Diff line number Diff line Loading @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape Loading Loading @@ -447,7 +446,24 @@ class CborParserGenerator( } override fun payloadParser(member: MemberShape): RuntimeType { UNREACHABLE("No protocol using CBOR serialization supports payload binding") val shape = model.expectShape(member.target) val returnSymbol = returnSymbolToParse(shape) check(shape is UnionShape || shape is StructureShape) { "Payload parser should only be used on structure and union shapes." } return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName -> rustTemplate( """ pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { let decoder = &mut #{Decoder}::new(value); #{DeserializeMember} } """, "ReturnType" to returnSymbol.symbol, "DeserializeMember" to deserializeMember(member), *codegenScope, ) } } override fun operationParser(operationShape: OperationShape): RuntimeType? { Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +20 −2 Original line number Diff line number Diff line Loading @@ -204,9 +204,26 @@ class CborSerializerGenerator( } } // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun payloadSerializer(member: MemberShape): RuntimeType { TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") val target = model.expectShape(member.target) return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> rustBlockTemplate( "pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>", *codegenScope, "target" to symbolProvider.toSymbol(target), ) { rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) rustBlock("") { rust("let encoder = &mut encoder;") when (target) { is StructureShape -> serializeStructure(StructContext("input", target)) is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions") } } rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) } } } override fun unsetStructure(structure: StructureShape): RuntimeType = Loading @@ -223,6 +240,7 @@ class CborSerializerGenerator( } val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val inputShape = operationShape.inputShape(model) return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> rustBlockTemplate( Loading codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +61 −27 Original line number Diff line number Diff line Loading @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.testutil import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.dataformat.cbor.CBORFactory import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson Loading @@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import java.util.Base64 private fun fillInBaseModel( protocolName: String, namespacedProtocolName: String, extraServiceAnnotations: String = "", ): String = """ namespace test use smithy.framework#ValidationException use aws.protocols#$protocolName use $namespacedProtocolName union TestUnion { Foo: String, Loading Loading @@ -86,22 +90,24 @@ private fun fillInBaseModel( } $extraServiceAnnotations @$protocolName @${namespacedProtocolName.substringAfter("#")} service TestService { version: "123", operations: [TestStreamOp] } """ object EventStreamTestModels { private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel() private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel() private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel() private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel() private fun awsQuery(): Model = fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() private fun ec2Query(): Model = fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() data class TestCase( val protocolShapeId: String, Loading @@ -120,11 +126,21 @@ object EventStreamTestModels { override fun toString(): String = protocolShapeId } val TEST_CASES = listOf( // // restJson1 // private fun base64Encode(input: ByteArray): String { val encodedBytes = Base64.getEncoder().encode(input) return String(encodedBytes) } private fun createCBORFromJSON(jsonString: String): ByteArray { val jsonMapper = ObjectMapper() val cborMapper = ObjectMapper(CBORFactory()) // Parse JSON string to a generic type. val jsonData = jsonMapper.readValue(jsonString, Any::class.java) // Convert the parsed data to CBOR. return cborMapper.writeValueAsBytes(jsonData) } private val restJsonTestCase = TestCase( protocolShapeId = "aws.protocols#restJson1", model = restJson1(), Loading @@ -137,22 +153,40 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", ) { RestJson(it) }, ) { RestJson(it) } val TEST_CASES = listOf( // // restJson1 // restJsonTestCase, // // rpcV2Cbor // restJsonTestCase.copy( protocolShapeId = "smithy.protocols#rpcv2Cbor", model = rpcv2Cbor(), mediaType = "application/cbor", responseContentType = "application/cbor", eventStreamMessageContentType = "application/cbor", validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)), validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)), validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)), validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)), protocolBuilder = { RpcV2Cbor(it) }, ), // // awsJson1_1 // TestCase( restJsonTestCase.copy( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", eventStreamMessageContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", ) { AwsJson(it, AwsJsonVersion.Json11) }, // // restXml Loading codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +3 −2 Original line number Diff line number Diff line Loading @@ -134,8 +134,7 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) model = codegenDecorator.transformModel(service, baseModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( Loading Loading @@ -210,6 +209,8 @@ open class ServerCodegenVisitor( .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) } // Tag aggregate shapes reachable from operation input .let(ShapesReachableFromOperationInputTagger::transform) // Remove traits that are not supported by the chosen protocol .let { ServerProtocolBasedTransformationFactory.transform(it, settings) } // Normalize event stream operations .let(EventStreamNormalizer::transform) Loading Loading
codegen-core/build.gradle.kts +1 −0 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ dependencies { implementation("org.jsoup:jsoup:1.16.2") api("software.amazon.smithy:smithy-codegen-core:$smithyVersion") api("com.moandjiezana.toml:toml4j:0.7.2") implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0") implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/CborParserGenerator.kt +18 −2 Original line number Diff line number Diff line Loading @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape Loading Loading @@ -447,7 +446,24 @@ class CborParserGenerator( } override fun payloadParser(member: MemberShape): RuntimeType { UNREACHABLE("No protocol using CBOR serialization supports payload binding") val shape = model.expectShape(member.target) val returnSymbol = returnSymbolToParse(shape) check(shape is UnionShape || shape is StructureShape) { "Payload parser should only be used on structure and union shapes." } return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName -> rustTemplate( """ pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> { let decoder = &mut #{Decoder}::new(value); #{DeserializeMember} } """, "ReturnType" to returnSymbol.symbol, "DeserializeMember" to deserializeMember(member), *codegenScope, ) } } override fun operationParser(operationShape: OperationShape): RuntimeType? { Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/CborSerializerGenerator.kt +20 −2 Original line number Diff line number Diff line Loading @@ -204,9 +204,26 @@ class CborSerializerGenerator( } } // TODO(https://github.com/smithy-lang/smithy-rs/issues/3573) override fun payloadSerializer(member: MemberShape): RuntimeType { TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573") val target = model.expectShape(member.target) return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName -> rustBlockTemplate( "pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>", *codegenScope, "target" to symbolProvider.toSymbol(target), ) { rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope) rustBlock("") { rust("let encoder = &mut encoder;") when (target) { is StructureShape -> serializeStructure(StructContext("input", target)) is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target)) else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions") } } rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope) } } } override fun unsetStructure(structure: StructureShape): RuntimeType = Loading @@ -223,6 +240,7 @@ class CborSerializerGenerator( } val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val inputShape = operationShape.inputShape(model) return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName -> rustBlockTemplate( Loading
codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +61 −27 Original line number Diff line number Diff line Loading @@ -5,6 +5,8 @@ package software.amazon.smithy.rust.codegen.core.testutil import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.dataformat.cbor.CBORFactory import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson Loading @@ -12,16 +14,18 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor import java.util.Base64 private fun fillInBaseModel( protocolName: String, namespacedProtocolName: String, extraServiceAnnotations: String = "", ): String = """ namespace test use smithy.framework#ValidationException use aws.protocols#$protocolName use $namespacedProtocolName union TestUnion { Foo: String, Loading Loading @@ -86,22 +90,24 @@ private fun fillInBaseModel( } $extraServiceAnnotations @$protocolName @${namespacedProtocolName.substringAfter("#")} service TestService { version: "123", operations: [TestStreamOp] } """ object EventStreamTestModels { private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel() private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel() private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel() private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel() private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel() private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel() private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel() private fun awsQuery(): Model = fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() private fun ec2Query(): Model = fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel() data class TestCase( val protocolShapeId: String, Loading @@ -120,11 +126,21 @@ object EventStreamTestModels { override fun toString(): String = protocolShapeId } val TEST_CASES = listOf( // // restJson1 // private fun base64Encode(input: ByteArray): String { val encodedBytes = Base64.getEncoder().encode(input) return String(encodedBytes) } private fun createCBORFromJSON(jsonString: String): ByteArray { val jsonMapper = ObjectMapper() val cborMapper = ObjectMapper(CBORFactory()) // Parse JSON string to a generic type. val jsonData = jsonMapper.readValue(jsonString, Any::class.java) // Convert the parsed data to CBOR. return cborMapper.writeValueAsBytes(jsonData) } private val restJsonTestCase = TestCase( protocolShapeId = "aws.protocols#restJson1", model = restJson1(), Loading @@ -137,22 +153,40 @@ object EventStreamTestModels { validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", ) { RestJson(it) }, ) { RestJson(it) } val TEST_CASES = listOf( // // restJson1 // restJsonTestCase, // // rpcV2Cbor // restJsonTestCase.copy( protocolShapeId = "smithy.protocols#rpcv2Cbor", model = rpcv2Cbor(), mediaType = "application/cbor", responseContentType = "application/cbor", eventStreamMessageContentType = "application/cbor", validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)), validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)), validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)), validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)), validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)), protocolBuilder = { RpcV2Cbor(it) }, ), // // awsJson1_1 // TestCase( restJsonTestCase.copy( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", eventStreamMessageContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", validTestUnion = """{"Foo":"hello"}""", validSomeError = """{"Message":"some error"}""", validUnmodeledError = """{"Message":"unmodeled error"}""", ) { AwsJson(it, AwsJsonVersion.Json11) }, // // restXml Loading
codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +3 −2 Original line number Diff line number Diff line Loading @@ -134,8 +134,7 @@ open class ServerCodegenVisitor( .protocolFor(context.model, service) this.protocolGeneratorFactory = protocolGeneratorFactory val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service) model = codegenDecorator.transformModel(service, protocolTransformedModel, settings) model = codegenDecorator.transformModel(service, baseModel, settings) val serverSymbolProviders = ServerSymbolProviders.from( Loading Loading @@ -210,6 +209,8 @@ open class ServerCodegenVisitor( .let { AttachValidationExceptionToConstrainedOperationInputs.transform(it, settings) } // Tag aggregate shapes reachable from operation input .let(ShapesReachableFromOperationInputTagger::transform) // Remove traits that are not supported by the chosen protocol .let { ServerProtocolBasedTransformationFactory.transform(it, settings) } // Normalize event stream operations .let(EventStreamNormalizer::transform) Loading