Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt +2 −2 Original line number Diff line number Diff line Loading @@ -285,7 +285,7 @@ class RustWriter private constructor( * - If the field is an unboxed primitive, it will only be called if the field is non-zero * */ fun ifSet(shape: Shape, member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) { fun ifSet(shape: Shape, member: Symbol, outerField: String, block: RustWriter.(field: String) -> Unit) { // TODO: this API should be refactored so that we don't need to strip `&` to get reference comparisons to work. when { member.isOptional() -> { Loading @@ -307,7 +307,7 @@ class RustWriter private constructor( fun listForEach( target: Shape, outerField: String, block: CodeWriter.(field: String, target: ShapeId) -> Unit block: RustWriter.(field: String, target: ShapeId) -> Unit ) { if (target is CollectionShape) { val derefName = safeName("inner") Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt +27 −15 Original line number Diff line number Diff line Loading @@ -26,8 +26,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.hasTrait Loading Loading @@ -59,16 +61,30 @@ class RequestBindingGenerator( val model: Model, private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig, private val writer: RustWriter, private val defaultTimestampFormat: TimestampFormatTrait.Format, private val shape: OperationShape, private val inputShape: StructureShape, private val httpTrait: HttpTrait private val httpTrait: HttpTrait, ) { // TODO: make defaultTimestampFormat configurable private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS private val index = HttpBindingIndex.of(model) private val buildError = runtimeConfig.operationBuildError() constructor( protocolConfig: ProtocolConfig, defaultTimestampFormat: TimestampFormatTrait.Format, httpBindingResolver: HttpBindingResolver, shape: OperationShape, inputShape: StructureShape, ) : this( protocolConfig.model, protocolConfig.symbolProvider, protocolConfig.runtimeConfig, defaultTimestampFormat, shape, inputShape, httpBindingResolver.httpTrait(shape), ) /** * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by * [implBlockWriter]. The specific behavior is configured by [httpTrait]. Loading Loading @@ -145,7 +161,7 @@ class RequestBindingGenerator( #{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)} })?; use std::convert::TryFrom; let header_value = ${headerFmtFun(target, memberShape, "v")}; let header_value = ${headerFmtFun(this, target, memberShape, "v")}; let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| { #{build_error}::InvalidField { field: ${memberName.dq()}, Loading Loading @@ -174,7 +190,7 @@ class RequestBindingGenerator( ifSet(memberType, memberSymbol, "&self.$memberName") { field -> listForEach(memberType, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) val formatted = headerFmtFun(innerMemberType, memberShape, innerField) val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField) val safeName = safeName("formatted") write("let $safeName = $formatted;") rustBlock("if !$safeName.is_empty()") { Loading Loading @@ -203,7 +219,7 @@ class RequestBindingGenerator( /** * Format [member] in the when used as an HTTP header */ private fun headerFmtFun(target: Shape, member: MemberShape, targetName: String): String { private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { return when { target.isStringShape -> { if (target.hasTrait<MediaTypeTrait>()) { Loading Loading @@ -238,7 +254,7 @@ class RequestBindingGenerator( val formatString = httpTrait.uriFormatString() val args = httpTrait.uri.labels.map { label -> val member = inputShape.expectMember(label.content) "${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}" "${label.content} = ${labelFmtFun(writer, model.expectShape(member.target), member, label)}" } val combinedArgs = listOf(formatString, *args.toTypedArray()) writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null) Loading Loading @@ -316,11 +332,7 @@ class RequestBindingGenerator( val target = model.expectShape(targetId) rust( "query.push_kv(${param.locationName.dq()}, &${ paramFmtFun( target, memberShape, innerField ) paramFmtFun(writer, target, memberShape, innerField) });" ) } Loading @@ -333,7 +345,7 @@ class RequestBindingGenerator( /** * Format [member] when used as a queryParam */ private fun paramFmtFun(target: Shape, member: MemberShape, targetName: String): String { private fun paramFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { return when { target.isStringShape -> { val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string")) Loading @@ -359,7 +371,7 @@ class RequestBindingGenerator( /** * Format [member] when used as an HTTP Label (`/bucket/{key}`) */ private fun labelFmtFun(target: Shape, member: MemberShape, label: SmithyPattern.Segment): String { private fun labelFmtFun(writer: RustWriter, target: Shape, member: MemberShape, label: SmithyPattern.Segment): String { val memberName = symbolProvider.toMemberName(member) return when { target.isStringShape -> { Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt +10 −7 Original line number Diff line number Diff line Loading @@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.smithy.Default import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.defaultValue import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait Loading Loading @@ -62,8 +64,8 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * } * ``` */ fun generateDeserializeHeaderFn(binding: HttpBinding): RuntimeType { check(binding.location == HttpBinding.Location.HEADER) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpLocation.HEADER) val outputT = symbolProvider.toSymbol(binding.member) val fnName = "deser_header_${fnName(operationShape, binding)}" return RuntimeType.forInlineFun(fnName, "http_serde") { writer -> Loading @@ -79,7 +81,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } } fun generateDeserializePrefixHeaderFn(binding: HttpBinding): RuntimeType { fun generateDeserializePrefixHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpBinding.Location.PREFIX_HEADERS) val outputT = symbolProvider.toSymbol(binding.member) check(outputT.rustType().stripOuter<RustType.Option>() is RustType.HashMap) { outputT.rustType() } Loading Loading @@ -122,7 +124,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * Generate a function to deserialize `[binding]` from the response payload */ fun generateDeserializePayloadFn( binding: HttpBinding, binding: HttpBindingDescriptor, errorT: RuntimeType, // Deserialize a single structure or union member marked as a payload structuredHandler: RustWriter.(String) -> Unit, Loading Loading @@ -156,7 +158,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } private fun RustWriter.deserializeStreamingBody( binding: HttpBinding, binding: HttpBindingDescriptor, ) { val member = binding.member val targetShape = model.expectShape(member.target) Loading @@ -172,7 +174,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } private fun RustWriter.deserializePayloadBody( binding: HttpBinding, binding: HttpBindingDescriptor, errorSymbol: RuntimeType, structuredHandler: RustWriter.(String) -> Unit, docShapeHandler: RustWriter.(String) -> Unit Loading Loading @@ -312,5 +314,6 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * Generate a unique name for the deserializer function for a given operationShape -> member pair */ // rename here technically not required, operations and members cannot be renamed private fun fnName(operationShape: OperationShape, binding: HttpBinding) = "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}" private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) = "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}" } codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt +37 −1 Original line number Diff line number Diff line Loading @@ -8,10 +8,13 @@ package software.amazon.smithy.rust.codegen.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustType Loading Loading @@ -46,6 +49,7 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStream import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape sealed class AwsJsonVersion { Loading Loading @@ -148,6 +152,38 @@ class SyntheticBodySymbolProvider(private val model: Model, private val base: Ru } } class AwsJsonHttpBindingResolver( private val model: Model, private val awsJsonVersion: AwsJsonVersion, ) : HttpBindingResolver { private val httpTrait = HttpTrait.builder() .code(200) .method("POST") .uri(UriPattern.parse("/")) .build() private fun bindings(shape: ToShapeId?) = shape?.let { model.expectShape(it.toShapeId()) }?.members() ?.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") } ?.toList() ?: emptyList() override fun httpTrait(operationShape: OperationShape): HttpTrait = httpTrait override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> = bindings(operationShape.input.orNull()) override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> = bindings(operationShape.output.orNull()) override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> = bindings(errorShape) override fun requestContentType(operationShape: OperationShape): String = "application/x-amz-json-${awsJsonVersion.value}" } // TODO: Refactor to use HttpBoundProtocolGenerator class BasicAwsJsonGenerator( private val protocolConfig: ProtocolConfig, private val awsJsonVersion: AwsJsonVersion Loading Loading @@ -198,7 +234,7 @@ class BasicAwsJsonGenerator( } override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { val generator = JsonSerializerGenerator(protocolConfig) val generator = JsonSerializerGenerator(protocolConfig, AwsJsonHttpBindingResolver(model, awsJsonVersion)) val serializer = generator.operationSerializer(operationShape) serializer?.also { sym -> rustTemplate( Loading codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt 0 → 100644 +86 −0 Original line number Diff line number Diff line /* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0. */ package software.amazon.smithy.rust.codegen.smithy.protocols import software.amazon.smithy.model.Model import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations class AwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> { override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig)) override fun transformModel(model: Model): Model { return OperationNormalizer(model).transformModel( inputBodyFactory = OperationNormalizer.NoBody, outputBodyFactory = OperationNormalizer.NoBody ).let(RemoveEventStreamOperations::transform) } override fun support(): ProtocolSupport { return ProtocolSupport( requestSerialization = true, requestBodySerialization = true, responseDeserialization = true, errorDeserialization = true, ) } } class AwsQueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { private val runtimeConfig = protocolConfig.runtimeConfig private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig) override val httpBindingResolver: HttpBindingResolver = StaticHttpBindingResolver( protocolConfig.model, HttpTrait.builder() .code(200) .method("POST") .uri(UriPattern.parse("/")) .build(), "application/x-www-form-urlencoded" ) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = AwsQueryParserGenerator(protocolConfig, awsQueryErrors) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(protocolConfig) override fun parseGenericError(operationShape: OperationShape): RuntimeType { /** fn parse_generic(response: &Response<Bytes>) -> Result<smithy_types::error::Generic, T: Error> **/ return RuntimeType.forInlineFun("parse_generic_error", "xml_deser") { it.rustBlockTemplate( "pub fn parse_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlError}>", "Response" to RuntimeType.http.member("Response"), "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), "XmlError" to CargoDependency.smithyXml(runtimeConfig).asType().member("decode::XmlError") ) { rust("#T::parse_generic_error(response.body().as_ref())", awsQueryErrors) } } } } Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt +2 −2 Original line number Diff line number Diff line Loading @@ -285,7 +285,7 @@ class RustWriter private constructor( * - If the field is an unboxed primitive, it will only be called if the field is non-zero * */ fun ifSet(shape: Shape, member: Symbol, outerField: String, block: CodeWriter.(field: String) -> Unit) { fun ifSet(shape: Shape, member: Symbol, outerField: String, block: RustWriter.(field: String) -> Unit) { // TODO: this API should be refactored so that we don't need to strip `&` to get reference comparisons to work. when { member.isOptional() -> { Loading @@ -307,7 +307,7 @@ class RustWriter private constructor( fun listForEach( target: Shape, outerField: String, block: CodeWriter.(field: String, target: ShapeId) -> Unit block: RustWriter.(field: String, target: ShapeId) -> Unit ) { if (target is CollectionShape) { val derefName = safeName("inner") Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt +27 −15 Original line number Diff line number Diff line Loading @@ -26,8 +26,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.redactIfNecessary import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.hasTrait Loading Loading @@ -59,16 +61,30 @@ class RequestBindingGenerator( val model: Model, private val symbolProvider: RustSymbolProvider, private val runtimeConfig: RuntimeConfig, private val writer: RustWriter, private val defaultTimestampFormat: TimestampFormatTrait.Format, private val shape: OperationShape, private val inputShape: StructureShape, private val httpTrait: HttpTrait private val httpTrait: HttpTrait, ) { // TODO: make defaultTimestampFormat configurable private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS private val index = HttpBindingIndex.of(model) private val buildError = runtimeConfig.operationBuildError() constructor( protocolConfig: ProtocolConfig, defaultTimestampFormat: TimestampFormatTrait.Format, httpBindingResolver: HttpBindingResolver, shape: OperationShape, inputShape: StructureShape, ) : this( protocolConfig.model, protocolConfig.symbolProvider, protocolConfig.runtimeConfig, defaultTimestampFormat, shape, inputShape, httpBindingResolver.httpTrait(shape), ) /** * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by * [implBlockWriter]. The specific behavior is configured by [httpTrait]. Loading Loading @@ -145,7 +161,7 @@ class RequestBindingGenerator( #{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)} })?; use std::convert::TryFrom; let header_value = ${headerFmtFun(target, memberShape, "v")}; let header_value = ${headerFmtFun(this, target, memberShape, "v")}; let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| { #{build_error}::InvalidField { field: ${memberName.dq()}, Loading Loading @@ -174,7 +190,7 @@ class RequestBindingGenerator( ifSet(memberType, memberSymbol, "&self.$memberName") { field -> listForEach(memberType, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) val formatted = headerFmtFun(innerMemberType, memberShape, innerField) val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField) val safeName = safeName("formatted") write("let $safeName = $formatted;") rustBlock("if !$safeName.is_empty()") { Loading Loading @@ -203,7 +219,7 @@ class RequestBindingGenerator( /** * Format [member] in the when used as an HTTP header */ private fun headerFmtFun(target: Shape, member: MemberShape, targetName: String): String { private fun headerFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { return when { target.isStringShape -> { if (target.hasTrait<MediaTypeTrait>()) { Loading Loading @@ -238,7 +254,7 @@ class RequestBindingGenerator( val formatString = httpTrait.uriFormatString() val args = httpTrait.uri.labels.map { label -> val member = inputShape.expectMember(label.content) "${label.content} = ${labelFmtFun(model.expectShape(member.target), member, label)}" "${label.content} = ${labelFmtFun(writer, model.expectShape(member.target), member, label)}" } val combinedArgs = listOf(formatString, *args.toTypedArray()) writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null) Loading Loading @@ -316,11 +332,7 @@ class RequestBindingGenerator( val target = model.expectShape(targetId) rust( "query.push_kv(${param.locationName.dq()}, &${ paramFmtFun( target, memberShape, innerField ) paramFmtFun(writer, target, memberShape, innerField) });" ) } Loading @@ -333,7 +345,7 @@ class RequestBindingGenerator( /** * Format [member] when used as a queryParam */ private fun paramFmtFun(target: Shape, member: MemberShape, targetName: String): String { private fun paramFmtFun(writer: RustWriter, target: Shape, member: MemberShape, targetName: String): String { return when { target.isStringShape -> { val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string")) Loading @@ -359,7 +371,7 @@ class RequestBindingGenerator( /** * Format [member] when used as an HTTP Label (`/bucket/{key}`) */ private fun labelFmtFun(target: Shape, member: MemberShape, label: SmithyPattern.Segment): String { private fun labelFmtFun(writer: RustWriter, target: Shape, member: MemberShape, label: SmithyPattern.Segment): String { val memberName = symbolProvider.toMemberName(member) return when { target.isStringShape -> { Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt +10 −7 Original line number Diff line number Diff line Loading @@ -34,6 +34,8 @@ import software.amazon.smithy.rust.codegen.smithy.Default import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.defaultValue import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait Loading Loading @@ -62,8 +64,8 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * } * ``` */ fun generateDeserializeHeaderFn(binding: HttpBinding): RuntimeType { check(binding.location == HttpBinding.Location.HEADER) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpLocation.HEADER) val outputT = symbolProvider.toSymbol(binding.member) val fnName = "deser_header_${fnName(operationShape, binding)}" return RuntimeType.forInlineFun(fnName, "http_serde") { writer -> Loading @@ -79,7 +81,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } } fun generateDeserializePrefixHeaderFn(binding: HttpBinding): RuntimeType { fun generateDeserializePrefixHeaderFn(binding: HttpBindingDescriptor): RuntimeType { check(binding.location == HttpBinding.Location.PREFIX_HEADERS) val outputT = symbolProvider.toSymbol(binding.member) check(outputT.rustType().stripOuter<RustType.Option>() is RustType.HashMap) { outputT.rustType() } Loading Loading @@ -122,7 +124,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * Generate a function to deserialize `[binding]` from the response payload */ fun generateDeserializePayloadFn( binding: HttpBinding, binding: HttpBindingDescriptor, errorT: RuntimeType, // Deserialize a single structure or union member marked as a payload structuredHandler: RustWriter.(String) -> Unit, Loading Loading @@ -156,7 +158,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } private fun RustWriter.deserializeStreamingBody( binding: HttpBinding, binding: HttpBindingDescriptor, ) { val member = binding.member val targetShape = model.expectShape(member.target) Loading @@ -172,7 +174,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera } private fun RustWriter.deserializePayloadBody( binding: HttpBinding, binding: HttpBindingDescriptor, errorSymbol: RuntimeType, structuredHandler: RustWriter.(String) -> Unit, docShapeHandler: RustWriter.(String) -> Unit Loading Loading @@ -312,5 +314,6 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera * Generate a unique name for the deserializer function for a given operationShape -> member pair */ // rename here technically not required, operations and members cannot be renamed private fun fnName(operationShape: OperationShape, binding: HttpBinding) = "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}" private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) = "${operationShape.id.getName(service).toSnakeCase()}_${binding.memberName.toSnakeCase()}" }
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson.kt +37 −1 Original line number Diff line number Diff line Loading @@ -8,10 +8,13 @@ package software.amazon.smithy.rust.codegen.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.OperationIndex import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustType Loading Loading @@ -46,6 +49,7 @@ import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStream import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait import software.amazon.smithy.rust.codegen.util.orNull import software.amazon.smithy.rust.codegen.util.outputShape sealed class AwsJsonVersion { Loading Loading @@ -148,6 +152,38 @@ class SyntheticBodySymbolProvider(private val model: Model, private val base: Ru } } class AwsJsonHttpBindingResolver( private val model: Model, private val awsJsonVersion: AwsJsonVersion, ) : HttpBindingResolver { private val httpTrait = HttpTrait.builder() .code(200) .method("POST") .uri(UriPattern.parse("/")) .build() private fun bindings(shape: ToShapeId?) = shape?.let { model.expectShape(it.toShapeId()) }?.members() ?.map { HttpBindingDescriptor(it, HttpLocation.DOCUMENT, "document") } ?.toList() ?: emptyList() override fun httpTrait(operationShape: OperationShape): HttpTrait = httpTrait override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> = bindings(operationShape.input.orNull()) override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> = bindings(operationShape.output.orNull()) override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> = bindings(errorShape) override fun requestContentType(operationShape: OperationShape): String = "application/x-amz-json-${awsJsonVersion.value}" } // TODO: Refactor to use HttpBoundProtocolGenerator class BasicAwsJsonGenerator( private val protocolConfig: ProtocolConfig, private val awsJsonVersion: AwsJsonVersion Loading Loading @@ -198,7 +234,7 @@ class BasicAwsJsonGenerator( } override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { val generator = JsonSerializerGenerator(protocolConfig) val generator = JsonSerializerGenerator(protocolConfig, AwsJsonHttpBindingResolver(model, awsJsonVersion)) val serializer = generator.operationSerializer(operationShape) serializer?.also { sym -> rustTemplate( Loading
codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsQuery.kt 0 → 100644 +86 −0 Original line number Diff line number Diff line /* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0. */ package software.amazon.smithy.rust.codegen.smithy.protocols import software.amazon.smithy.model.Model import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.AwsQuerySerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations class AwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator> { override fun buildProtocolGenerator(protocolConfig: ProtocolConfig): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(protocolConfig, AwsQueryProtocol(protocolConfig)) override fun transformModel(model: Model): Model { return OperationNormalizer(model).transformModel( inputBodyFactory = OperationNormalizer.NoBody, outputBodyFactory = OperationNormalizer.NoBody ).let(RemoveEventStreamOperations::transform) } override fun support(): ProtocolSupport { return ProtocolSupport( requestSerialization = true, requestBodySerialization = true, responseDeserialization = true, errorDeserialization = true, ) } } class AwsQueryProtocol(private val protocolConfig: ProtocolConfig) : Protocol { private val runtimeConfig = protocolConfig.runtimeConfig private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig) override val httpBindingResolver: HttpBindingResolver = StaticHttpBindingResolver( protocolConfig.model, HttpTrait.builder() .code(200) .method("POST") .uri(UriPattern.parse("/")) .build(), "application/x-www-form-urlencoded" ) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = AwsQueryParserGenerator(protocolConfig, awsQueryErrors) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(protocolConfig) override fun parseGenericError(operationShape: OperationShape): RuntimeType { /** fn parse_generic(response: &Response<Bytes>) -> Result<smithy_types::error::Generic, T: Error> **/ return RuntimeType.forInlineFun("parse_generic_error", "xml_deser") { it.rustBlockTemplate( "pub fn parse_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlError}>", "Response" to RuntimeType.http.member("Response"), "Bytes" to RuntimeType.Bytes, "Error" to RuntimeType.GenericError(runtimeConfig), "XmlError" to CargoDependency.smithyXml(runtimeConfig).asType().member("decode::XmlError") ) { rust("#T::parse_generic_error(response.body().as_ref())", awsQueryErrors) } } } }