diff --git a/aws/sdk-codegen-test/build.gradle.kts b/aws/sdk-codegen-test/build.gradle.kts index 83287ce4f71c8faf34f5101cb333527baff28417..ec6595c9c94eca5bae33f7a805a958c6448c733d 100644 --- a/aws/sdk-codegen-test/build.gradle.kts +++ b/aws/sdk-codegen-test/build.gradle.kts @@ -9,7 +9,7 @@ extra["moduleName"] = "software.amazon.smithy.kotlin.codegen.test" tasks["jar"].enabled = false plugins { - id("software.amazon.smithy").version("0.5.2") + id("software.amazon.smithy").version("0.5.3") } val smithyVersion: String by project diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index caf05e178971464714bd3b12c3ff047d440f2027..98e79ce435377865c765b40f4ba423f09fb11226 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -13,7 +13,7 @@ extra["moduleName"] = "software.amazon.smithy.rust.awssdk" tasks["jar"].enabled = false plugins { - id("software.amazon.smithy").version("0.5.2") + id("software.amazon.smithy").version("0.5.3") } val smithyVersion: String by project diff --git a/aws/sdk/integration-tests/kms/Cargo.toml b/aws/sdk/integration-tests/kms/Cargo.toml index 8352558dd339df3411e82ff4640deab67b86662a..e2bd29e6fd491eb3887674c1139a13443c5f4f79 100644 --- a/aws/sdk/integration-tests/kms/Cargo.toml +++ b/aws/sdk/integration-tests/kms/Cargo.toml @@ -17,3 +17,4 @@ aws-auth = { path = "../../build/aws-sdk/aws-auth" } aws-http = { path = "../../build/aws-sdk/aws-http" } tokio = { version = "1", features = ["full"]} tracing-subscriber = "0.2.16" +bytes = "1" diff --git a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs index 9654b00187ef477080f068404538c3baef4418ed..9a526d0bfc067f99affe1f471e670b28f451e269 100644 --- a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs +++ b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs @@ -5,12 +5,14 @@ use aws_http::AwsErrorRetryPolicy; use aws_sdk_kms as kms; +use bytes::Bytes; use kms::error::CreateAliasError; use kms::operation::{CreateAlias, GenerateRandom}; use kms::output::GenerateRandomOutput; use kms::Blob; use smithy_http::body::SdkBody; use smithy_http::operation::Parts; +use smithy_http::response::ParseStrictResponse; use smithy_http::result::SdkError; use smithy_http::retry::ClassifyResponse; use smithy_types::retry::{ErrorKind, RetryKind}; @@ -69,11 +71,13 @@ fn errors_are_retryable() { let op = create_alias_op(); let http_response = http::Response::builder() .status(400) - .body(r#"{ "code": "LimitExceededException" }"#) + .body(Bytes::from_static( + br#"{ "code": "LimitExceededException" }"#, + )) .unwrap(); let err = op .response_handler - .parse_response(&http_response) + .parse(&http_response) .map_err(|e| SdkError::ServiceError { err: e, raw: http_response.map(SdkBody::from), @@ -87,11 +91,11 @@ fn unmodeled_errors_are_retryable() { let op = create_alias_op(); let http_response = http::Response::builder() .status(400) - .body(r#"{ "code": "ThrottlingException" }"#) + .body(Bytes::from_static(br#"{ "code": "ThrottlingException" }"#)) .unwrap(); let err = op .response_handler - .parse_response(&http_response) + .parse(&http_response) .map_err(|e| SdkError::ServiceError { err: e, raw: http_response.map(SdkBody::from), diff --git a/codegen-test/build.gradle.kts b/codegen-test/build.gradle.kts index ba5e62d11d8c3748a152cd26425a857600333bb3..9fe3568f9e5ce68dedd8f1c8e668aa077b30c21c 100644 --- a/codegen-test/build.gradle.kts +++ b/codegen-test/build.gradle.kts @@ -9,7 +9,7 @@ extra["moduleName"] = "software.amazon.smithy.kotlin.codegen.test" tasks["jar"].enabled = false plugins { - id("software.amazon.smithy").version("0.5.2") + id("software.amazon.smithy").version("0.5.3") } val smithyVersion: String by project diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingTraitSymbolProvider.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingTraitSymbolProvider.kt index 67283b20ee001fda542f23f6bd93cc984ce42dd4..c518ddb102f68469027e30eb2abf3db70cbaed49 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingTraitSymbolProvider.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingTraitSymbolProvider.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.isStreaming @@ -33,7 +34,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private val container = model.expectShape(shape.container) // We are only targeting output shapes - if (!container.hasTrait(SyntheticOutputTrait::class.java)) { + if (!(container.hasTrait(SyntheticOutputTrait::class.java) || container.hasTrait(SyntheticInputTrait::class.java))) { return initial } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt index 82f542bfb18bde1cd04f17a3a1427e85c22eac4f..46266ea86d21eac815b8e16e42ae4ba93d6777da 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt @@ -17,6 +17,7 @@ import software.amazon.smithy.rust.codegen.rustlang.docs import software.amazon.smithy.rust.codegen.rustlang.documentShape import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType @@ -25,7 +26,6 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomizati import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.letIf -import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.inputShape @@ -81,18 +81,22 @@ abstract class HttpProtocolGenerator( inputWriter.implBlock(inputShape, symbolProvider) { buildOperation(this, operationShape, customizations, sdkId) toHttpRequestImpl(this, operationShape, inputShape) - val shapeId = inputShape.expectTrait(SyntheticInputTrait::class.java).body - val body = shapeId?.let { model.expectShape(it, StructureShape::class.java) } - toBodyImpl(this, inputShape, body, operationShape) // TODO: streaming shapes need special support rustBlock( - "pub fn assemble(builder: #1T, body: #3T) -> #2T<#3T>", + "fn assemble(mut builder: #1T, body: #3T) -> #2T<#3T>", RuntimeType.HttpRequestBuilder, RuntimeType.Http("request::Request"), - RuntimeType.ByteSlab + RuntimeType.sdkBody(protocolConfig.runtimeConfig) ) { - write("builder.header(#T, body.len()).body(body)", RuntimeType.Http("header::CONTENT_LENGTH")) - write(""".expect("http request should be valid")""") + rustTemplate( + """ + if let Some(content_length) = body.content_length() { + builder = builder.header(#{http}::header::CONTENT_LENGTH, content_length) + } + builder.body(body).expect("should be valid request") + """, + "http" to RuntimeType.http + ) } // pub fn builder() -> ... { } @@ -113,7 +117,7 @@ abstract class HttpProtocolGenerator( fromResponseImpl(this, operationShape) rustBlock( - "pub fn parse_response(&self, $mutability response: &$mutability #T<$type>) -> Result<#T, #T>", + "fn parse_response(&self, $mutability response: &$mutability #T<$type>) -> Result<#T, #T>", RuntimeType.Http("response::Response"), symbolProvider.toSymbol(operationShape.outputShape(model)), operationShape.errorSymbol(symbolProvider) @@ -140,21 +144,17 @@ abstract class HttpProtocolGenerator( } protected fun httpBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) { + Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter) implBlockWriter.rustBlock( - "pub fn request_builder_base(&self) -> Result<#T, #T>", + "fn request_builder_base(&self) -> Result<#T, #T>", RuntimeType.HttpRequestBuilder, buildErrorT ) { f(this) } } - protected fun bodyBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) { - implBlockWriter.rustBlock( - "pub fn build_body(&self) -> #T", RuntimeType.ByteSlab - ) { - f(this) - } - } + data class BodyMetadata(val takesOwnership: Boolean) + abstract fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata protected fun fromResponseFun( implBlockWriter: RustWriter, @@ -192,8 +192,9 @@ abstract class HttpProtocolGenerator( implBlockWriter.docs("Consumes the builder and constructs an Operation<#D>", outputSymbol) // For codegen simplicity, allow `let x = ...; x` implBlockWriter.rust("##[allow(clippy::let_and_return)]") + val bodyMetadata = RustWriter.root().body("self", shape) val mut = features.any { it.mutSelf() } - val consumes = features.any { it.consumesSelf() } + val consumes = features.any { it.consumesSelf() } || bodyMetadata.takesOwnership val self = "self".letIf(mut) { "mut $it" }.letIf(!consumes) { "&$it" } implBlockWriter.rustBlock( "pub fn make_operation($self, _config: &#T::Config) -> $returnType", @@ -201,7 +202,11 @@ abstract class HttpProtocolGenerator( ) { withBlock("Ok({", "})") { features.forEach { it.section(OperationSection.MutateInput("self", "_config"))(this) } - rust("let request = Self::assemble(self.request_builder_base()?, self.build_body());") + rust("let request = self.request_builder_base()?;") + withBlock("let body = ", ";") { + body("self", shape) + } + rust("let request = Self::assemble(request, body);") rust( """ ##[allow(unused_mut)] @@ -231,18 +236,6 @@ abstract class HttpProtocolGenerator( abstract fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) - /** - * Add necessary methods to the impl block to generate the request body - * - * Your implementation MUST call [bodyBuilderFun] to create the public method. - */ - abstract fun toBodyImpl( - implBlockWriter: RustWriter, - inputShape: StructureShape, - inputBody: StructureShape?, - operationShape: OperationShape - ) - /** * Add necessary methods to the impl block for the input shape. * diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index 9350cf75b7595fff460e08653322d85e9641a8c7..87abb00a73a6de4d32be7a9668e643f7dd0ccfaa 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -41,7 +41,6 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.rustType -import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.isStreaming @@ -152,8 +151,7 @@ class Instantiator( ctx.letIf(shape.getMemberTrait(model, HttpPrefixHeadersTrait::class.java).isPresent) { it.copy(lowercaseMapKeys = true) }.letIf( - shape.isStreaming(model) && - model.expectShape(shape.container).hasTrait(SyntheticOutputTrait::class.java) + shape.isStreaming(model) ) { it.copy(streaming = true) } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt index 019ca403c729e10f8380a5dca0400eb07ede1142..55898469d44e58e7dcc81746af6712d407ef4b61 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt @@ -13,8 +13,10 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate @@ -30,17 +32,21 @@ import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.locatedIn import software.amazon.smithy.rust.codegen.smithy.meta import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations import software.amazon.smithy.rust.codegen.smithy.transformers.StructureModifier import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape +import software.amazon.smithy.rust.codegen.util.toSnakeCase sealed class AwsJsonVersion { abstract val value: String @@ -146,6 +152,8 @@ class BasicAwsJsonGenerator( private val awsJsonVersion: AwsJsonVersion ) : HttpProtocolGenerator(protocolConfig) { private val model = protocolConfig.model + private val runtimeConfig = protocolConfig.runtimeConfig + override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) { val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) val operationName = symbolProvider.toSymbol(operationShape).name @@ -158,7 +166,7 @@ class BasicAwsJsonGenerator( } } """, - "parse_strict" to RuntimeType.parseStrict(symbolProvider.config().runtimeConfig), + "parse_strict" to RuntimeType.parseStrict(runtimeConfig), "output" to outputSymbol, "error" to operationShape.errorSymbol(symbolProvider), "response" to RuntimeType.Http("Response"), @@ -189,30 +197,47 @@ class BasicAwsJsonGenerator( } } - override fun toBodyImpl( - implBlockWriter: RustWriter, - inputShape: StructureShape, - inputBody: StructureShape?, - operationShape: OperationShape - ) { + override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { + val fnName = "${operationShape.id.name.toSnakeCase()}_mk_body" + val inputShape = operationShape.inputShape(model) + val inputBody = inputShape.expectTrait(SyntheticInputTrait::class.java).body?.let { + model.expectShape( + it, + StructureShape::class.java + ) + } + val sdkBody = RuntimeType.sdkBody(runtimeConfig) if (inputBody == null) { - bodyBuilderFun(implBlockWriter) { - write("\"{}\".to_string().into()") - } - return + rustTemplate("""#{SdkBody}::from("{}")""", "SdkBody" to sdkBody) + return BodyMetadata(takesOwnership = false) } - val bodySymbol = protocolConfig.symbolProvider.toSymbol(inputBody) - implBlockWriter.rustBlock("fn body(&self) -> #T", bodySymbol) { - rustBlock("#T", bodySymbol) { - for (member in inputBody.members()) { - val name = protocolConfig.symbolProvider.toMemberName(member) - write("$name: &self.$name,") + val bodySer = RuntimeType.forInlineFun(fnName, "operation_ser") { + it.rustBlock( + "pub fn $fnName(input: &#T) -> Result<#T, #T>", + symbolProvider.toSymbol(inputShape), + RuntimeType.sdkBody(runtimeConfig), + runtimeConfig.operationBuildError() + ) { + withBlock("let body = ", ";") { + rustBlock("#T", symbolProvider.toSymbol(inputBody)) { + for (member in inputBody.members()) { + val name = protocolConfig.symbolProvider.toMemberName(member) + write("$name: &input.$name,") + } + } } + rustTemplate( + """#{serde_json}::to_vec(&body) + .map(#{SdkBody}::from) + .map_err(|err|#{BuildError}::SerializationError(err.into()))""", + "serde_json" to CargoDependency.SerdeJson.asType(), + "BuildError" to runtimeConfig.operationBuildError(), + "SdkBody" to sdkBody + ) } } - bodyBuilderFun(implBlockWriter) { - write("""#T(&self.body()).expect("serialization should succeed")""", RuntimeType.SerdeJson("to_vec")) - } + rust("#T(&$self)?", bodySer) + return BodyMetadata(takesOwnership = false) } override fun fromResponseImpl(implBlockWriter: RustWriter, operationShape: OperationShape) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt index 582acf2398883680c193e02765fcca4acf8e04eb..c0ba0a9d5ae4f7e2a2fb508dd42c0c0bc8e0c000 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape @@ -41,16 +42,20 @@ import software.amazon.smithy.rust.codegen.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.http.RequestBindingGenerator import software.amazon.smithy.rust.codegen.smithy.generators.http.ResponseBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.setterName import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.smithy.transformers.RemoveEventStreamOperations import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.outputShape +import software.amazon.smithy.rust.codegen.util.toSnakeCase import java.util.logging.Logger class AwsRestJsonFactory : ProtocolGeneratorFactory { @@ -426,93 +431,147 @@ class AwsRestJsonGenerator( } } - private fun serializeViaSyntheticBody( - implBlockWriter: RustWriter, + private fun RustWriter.serializeViaSyntheticBody( + self: String, + inputShape: StructureShape, inputBody: StructureShape ) { - val bodySymbol = protocolConfig.symbolProvider.toSymbol(inputBody) - implBlockWriter.rustBlock("fn body(&self) -> #T", bodySymbol) { - rustBlock("#T", bodySymbol) { - for (member in inputBody.members()) { - val name = protocolConfig.symbolProvider.toMemberName(member) - write("$name: &self.$name,") + val fnName = "synth_body_${inputBody.id.name.toSnakeCase()}" + val bodySer = RuntimeType.forInlineFun(fnName, "operation_ser") { + it.rustBlock( + "pub fn $fnName(input: &#T) -> Result<#T, #T>", + symbolProvider.toSymbol(inputShape), + RuntimeType.sdkBody(runtimeConfig), + runtimeConfig.operationBuildError() + ) { + withBlock("let body = ", ";") { + rustBlock("#T", symbolProvider.toSymbol(inputBody)) { + for (member in inputBody.members()) { + val name = protocolConfig.symbolProvider.toMemberName(member) + write("$name: &input.$name,") + } + } } + rustTemplate( + """#{serde_json}::to_vec(&body) + .map(#{SdkBody}::from) + .map_err(|err|#{BuildError}::SerializationError(err.into()))""", + "serde_json" to CargoDependency.SerdeJson.asType(), + "BuildError" to runtimeConfig.operationBuildError(), + "SdkBody" to sdkBody + ) } } - bodyBuilderFun(implBlockWriter) { - write("""#T(&self.body()).expect("serialization should succeed")""", RuntimeType.SerdeJson("to_vec")) - } + rust("#T(&$self)?", bodySer) } - override fun toBodyImpl( - implBlockWriter: RustWriter, - inputShape: StructureShape, - inputBody: StructureShape?, - operationShape: OperationShape - ) { - // If we created a synthetic input body, serialize that + override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { + val inputShape = operationShape.inputShape(model) + val inputBody = inputShape.expectTrait(SyntheticInputTrait::class.java).body?.let { + model.expectShape( + it, + StructureShape::class.java + ) + } if (inputBody != null) { - return serializeViaSyntheticBody(implBlockWriter, inputBody) + serializeViaSyntheticBody(self, inputShape, inputBody) + return BodyMetadata(takesOwnership = false) } - - // Otherwise, we need to serialize via the HTTP payload trait val bindings = httpIndex.getRequestBindings(operationShape).toList() - val payload: Pair? = - bindings.firstOrNull { (_, binding) -> binding.location == HttpBinding.Location.PAYLOAD } - val payloadSerde = payload?.let { (payloadMemberName, _) -> + val payloadMemberName: String? = + bindings.firstOrNull { (_, binding) -> binding.location == HttpBinding.Location.PAYLOAD }?.first + if (payloadMemberName == null) { + rustTemplate("""#{SdkBody}::from("")""", "SdkBody" to sdkBody) + return BodyMetadata(takesOwnership = false) + } else { val member = inputShape.expectMember(payloadMemberName) - val rustMemberName = "self.${symbolProvider.toMemberName(member)}" - val targetShape = model.expectShape(member.target) - writable { - val payloadName = safeName() - rust("let $payloadName = &$rustMemberName;") + return serializeViaPayload(member) + } + } + + private fun RustWriter.serializeViaPayload(member: MemberShape): BodyMetadata { + val fnName = "ser_payload_${member.container.name.toSnakeCase()}" + val targetShape = model.expectShape(member.target) + val bodyMetadata: BodyMetadata = RustWriter.root().renderPayload(targetShape, "payload") + val ref = when (bodyMetadata.takesOwnership) { + true -> "" + false -> "&" + } + val serializer = RuntimeType.forInlineFun(fnName, "operation_ser") { + it.rustBlock( + "pub fn $fnName(payload: $ref#T) -> Result<#T, #T>", + symbolProvider.toSymbol(member), + sdkBody, + runtimeConfig.operationBuildError() + ) { // If this targets a member & the member is None, return an empty vec + val ref = when (bodyMetadata.takesOwnership) { + false -> ".as_ref()" + true -> "" + } + if (symbolProvider.toSymbol(member).isOptional()) { - rust( + rustTemplate( """ - let $payloadName = match $payloadName.as_ref() { + let payload = match payload$ref { Some(t) => t, - None => return vec![] - };""" + None => return Ok(#{SdkBody}::from("")) + };""", + "SdkBody" to sdkBody ) } - renderPayload(targetShape, payloadName) + // When the body is a streaming blob it _literally_ is a SdkBody already + // mute this clippy warning to make the codegen a little simpler + Attribute.Custom("allow(clippy::useless_conversion)").render(this) + withBlock("Ok(#T::from(", "))", sdkBody) { + renderPayload(targetShape, "payload") + } } - // body is null, no payload set, so this is empty - } ?: writable { rust("vec![]") } - bodyBuilderFun(implBlockWriter) { - payloadSerde(this) } + rust("#T($ref self.${symbolProvider.toMemberName(member)})?", serializer) + return bodyMetadata } private fun RustWriter.renderPayload( targetShape: Shape, payloadName: String, - ) { + ): BodyMetadata { val serdeToVec = RuntimeType.SerdeJson("to_vec") - when (targetShape) { + return when (targetShape) { // Write the raw string to the payload - is StringShape -> + is StringShape -> { if (targetShape.hasTrait(EnumTrait::class.java)) { - rust("$payloadName.as_str().into()") + rust("$payloadName.as_str()") } else { - rust("""$payloadName.to_string().into()""") + rust("""$payloadName.to_string()""") } - is BlobShape -> + BodyMetadata(takesOwnership = false) + } + + // This works for streaming & non streaming blobs because they both have `into_inner()` which + // can be converted into an SDK body! + is BlobShape -> { // Write the raw blob to the payload - rust("$payloadName.as_ref().into()") - is StructureShape, is UnionShape -> + rust("$payloadName.into_inner()") + BodyMetadata(takesOwnership = true) + } + is StructureShape, is UnionShape -> { // JSON serialize the structure or union targetted rust( - """#T(&$payloadName).expect("serialization should succeed")""", - serdeToVec + """#T(&$payloadName).map_err(|err|#T::SerializationError(err.into()))?""", + serdeToVec, runtimeConfig.operationBuildError() ) - is DocumentShape -> + BodyMetadata(takesOwnership = false) + } + is DocumentShape -> { rustTemplate( - """#{to_vec}(&#{doc_json}::SerDoc(&$payloadName)).expect("serialization should succeed")""", + """#{to_vec}(&#{doc_json}::SerDoc(&$payloadName)).map_err(|err|#{BuildError}::SerializationError(err.into()))?""", "to_vec" to serdeToVec, - "doc_json" to RuntimeType.DocJson + "doc_json" to RuntimeType.DocJson, + "BuildError" to runtimeConfig.operationBuildError() ) + BodyMetadata(takesOwnership = false) + } else -> TODO("Unexpected payload target type") } } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingShapeSymbolProviderTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingShapeSymbolProviderTest.kt index c0240ebaadc796d97122f1e2dc23119ae29b2aad..757704ad5e31985513825dd4e0a1df25e0257e7e 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingShapeSymbolProviderTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/StreamingShapeSymbolProviderTest.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.smithy import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider @@ -39,17 +38,19 @@ internal class StreamingShapeSymbolProviderTest { OperationNormalizer(model).transformModel(OperationNormalizer.NoBody, OperationNormalizer.NoBody) val symbolProvider = testSymbolProvider(modelWithOperationTraits) symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechOutput\$data")).name shouldBe ("byte_stream::ByteStream") - symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechInput\$data")) - .rustType() shouldBe (RustType.Option(RustType.Opaque("Blob", "smithy_types"))) + symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechInput\$data")).name shouldBe ("byte_stream::ByteStream") } - // This is coming up soon, but for now, we only generate streaming binary output @Test - fun `do not generate a byte stream in streaming input`() { + fun `streaming members have a default`() { val modelWithOperationTraits = OperationNormalizer(model).transformModel(OperationNormalizer.NoBody, OperationNormalizer.NoBody) val symbolProvider = testSymbolProvider(modelWithOperationTraits) - symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechInput\$data")) - .rustType() shouldBe (RustType.Option(RustType.Opaque("Blob", "smithy_types"))) + + val outputSymbol = symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechOutput\$data")) + val inputSymbol = symbolProvider.toSymbol(modelWithOperationTraits.lookup("test#GenerateSpeechInput\$data")) + // Ensure that users don't need to set an input + outputSymbol.defaultValue() shouldBe Default.RustDefault + inputSymbol.defaultValue() shouldBe Default.RustDefault } } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt index 2d1ddc17c446fdbf44c265663ea8dd96933ef547..e0dcf0b8a6a8fb2399b34175426b8bb35ffdf508 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt @@ -122,6 +122,11 @@ class HttpProtocolTestGeneratorTest { // A stubbed test protocol to do enable testing intentionally broken protocols class TestProtocol(private val protocolConfig: ProtocolConfig) : HttpProtocolGenerator(protocolConfig) { private val symbolProvider = protocolConfig.symbolProvider + override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { + writeWithNoFormatting(body) + return BodyMetadata(takesOwnership = false) + } + override fun traitImplementations(operationWriter: RustWriter, operationShape: OperationShape) { operationWriter.rustTemplate( """ @@ -145,17 +150,6 @@ class HttpProtocolTestGeneratorTest { } } - override fun toBodyImpl( - implBlockWriter: RustWriter, - inputShape: StructureShape, - inputBody: StructureShape?, - operationShape: OperationShape - ) { - bodyBuilderFun(implBlockWriter) { - writeWithNoFormatting(body) - } - } - override fun toHttpRequestImpl( implBlockWriter: RustWriter, operationShape: OperationShape, diff --git a/rust-runtime/smithy-http/src/body.rs b/rust-runtime/smithy-http/src/body.rs index bffafe3c7708f8a144899771c0fde79c86e1bafa..e9ab2a0a1d5950b9a971debc672cf113476cd0ca 100644 --- a/rust-runtime/smithy-http/src/body.rs +++ b/rust-runtime/smithy-http/src/body.rs @@ -101,11 +101,15 @@ impl SdkBody { _ => None, } } + + pub fn content_length(&self) -> Option { + self.size_hint().exact() + } } impl From<&str> for SdkBody { fn from(s: &str) -> Self { - SdkBody(Inner::Once(Some(Bytes::copy_from_slice(s.as_bytes())))) + Self::from(s.as_bytes()) } } @@ -127,6 +131,18 @@ impl From> for SdkBody { } } +impl From for SdkBody { + fn from(s: String) -> Self { + Self::from(s.into_bytes()) + } +} + +impl From<&[u8]> for SdkBody { + fn from(data: &[u8]) -> Self { + SdkBody(Inner::Once(Some(Bytes::copy_from_slice(data)))) + } +} + impl http_body::Body for SdkBody { type Data = Bytes; type Error = Error; diff --git a/rust-runtime/smithy-http/src/operation.rs b/rust-runtime/smithy-http/src/operation.rs index 457b381c8afe06b1adc040099b33c2ba9b35482b..32de7a687ba68930c468c41528d12e2b1bc6d80e 100644 --- a/rust-runtime/smithy-http/src/operation.rs +++ b/rust-runtime/smithy-http/src/operation.rs @@ -6,6 +6,7 @@ use crate::body::SdkBody; use crate::property_bag::PropertyBag; use std::borrow::Cow; +use std::error::Error; use std::ops::DerefMut; use std::sync::{Arc, Mutex, MutexGuard}; use thiserror::Error; @@ -62,6 +63,8 @@ pub enum BuildError { field: &'static str, details: &'static str, }, + #[error("Failed during serialization: {0}")] + SerializationError(#[from] Box), } pub struct Operation { diff --git a/rust-runtime/smithy-types/src/lib.rs b/rust-runtime/smithy-types/src/lib.rs index 08f83659be422cda7190a087567e84eaf75fd7fd..0a2aacb7e6469ed6475e727b137e311ca7738655 100644 --- a/rust-runtime/smithy-types/src/lib.rs +++ b/rust-runtime/smithy-types/src/lib.rs @@ -22,6 +22,9 @@ impl Blob { pub fn new>>(inp: T) -> Self { Blob { inner: inp.into() } } + pub fn into_inner(self) -> Vec { + self.inner + } } impl AsRef<[u8]> for Blob {