From 5fa9c0c439a840cbc113a5a29f65d4e20620f2c9 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Thu, 30 Mar 2023 10:06:31 -0700 Subject: [PATCH] Generate `RequestSerializer` traits for client operations (#2510) * Clean up client deserialization generation logic * Generate `RequestSerializer` traits * Fix property access issue for customizations * Move `aws-smithy-runtime-test` into `aws/sra-test` * Generate operation runtime plugins to register ser/de * Add `send_v2` to fluent builders that uses the orchestrator * Fixes --- .../rustsdk/IntegrationTestDependencies.kt | 20 +- .../aws-smithy-runtime-test/src/auth.rs | 24 - .../aws-smithy-runtime-test/src/conn.rs | 44 -- .../aws-smithy-runtime-test/src/de.rs | 35 -- .../aws-smithy-runtime-test/src/endpoints.rs | 24 - .../aws-smithy-runtime-test/src/main.rs | 70 --- .../aws-smithy-runtime-test/src/retry.rs | 48 -- .../aws-smithy-runtime-test/src/ser.rs | 32 -- aws/sra-test/build.gradle.kts | 75 ++- .../integration-tests/aws-sdk-s3/.gitignore | 1 + .../integration-tests/aws-sdk-s3/Cargo.toml | 24 + .../aws-sdk-s3/tests/interceptors.rs | 103 ++++ .../aws-sdk-s3/tests/sra_test.rs | 24 + .../client/smithy/ClientCodegenVisitor.kt | 38 +- .../OperationRuntimePluginGenerator.kt | 43 ++ .../smithy/generators/ServiceGenerator.kt | 27 +- .../ServiceRuntimePluginGenerator.kt | 52 ++ .../client/FluentClientGenerator.kt | 40 ++ .../http/RequestBindingGenerator.kt | 4 +- .../protocol/ClientProtocolGenerator.kt | 17 +- .../protocol/MakeOperationGenerator.kt | 1 + .../protocol/ProtocolParserGenerator.kt | 344 +++++++++++++ .../protocol/RequestSerializerGenerator.kt | 143 ++++++ .../protocol/ResponseDeserializerGenerator.kt | 169 +++++++ .../protocols/HttpBoundProtocolGenerator.kt | 473 +----------------- .../protocol/ProtocolTestGeneratorTest.kt | 1 + .../rust/codegen/core/smithy/RuntimeType.kt | 2 + .../src/client/interceptors/context.rs | 33 +- .../src/client/interceptors/error.rs | 31 +- .../src/client/orchestrator.rs | 10 +- .../src/client/runtime_plugin.rs | 13 +- .../src/type_erasure.rs | 27 + .../src/client/orchestrator.rs | 14 +- 33 files changed, 1170 insertions(+), 836 deletions(-) delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/auth.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/conn.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/de.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/endpoints.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/main.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/retry.rs delete mode 100644 aws/sdk/integration-tests/aws-smithy-runtime-test/src/ser.rs create mode 100644 aws/sra-test/integration-tests/aws-sdk-s3/.gitignore create mode 100644 aws/sra-test/integration-tests/aws-sdk-s3/Cargo.toml create mode 100644 aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs create mode 100644 aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt create mode 100644 codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 35de46a16..6eedc44a8 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -28,7 +28,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Compani import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly @@ -56,8 +55,8 @@ class IntegrationTestDecorator : ClientCodegenDecorator { val hasTests = Files.exists(testPackagePath.resolve("tests")) val hasBenches = Files.exists(testPackagePath.resolve("benches")) baseCustomizations + IntegrationTestDependencies( + codegenContext, moduleName, - codegenContext.runtimeConfig, hasTests, hasBenches, ) @@ -68,18 +67,18 @@ class IntegrationTestDecorator : ClientCodegenDecorator { } class IntegrationTestDependencies( + private val codegenContext: ClientCodegenContext, private val moduleName: String, - private val runtimeConfig: RuntimeConfig, private val hasTests: Boolean, private val hasBenches: Boolean, ) : LibRsCustomization() { override fun section(section: LibRsSection) = when (section) { is LibRsSection.Body -> testDependenciesOnly { if (hasTests) { - val smithyClient = CargoDependency.smithyClient(runtimeConfig) + val smithyClient = CargoDependency.smithyClient(codegenContext.runtimeConfig) .copy(features = setOf("test-util"), scope = DependencyScope.Dev) addDependency(smithyClient) - addDependency(CargoDependency.smithyProtocolTestHelpers(runtimeConfig)) + addDependency(CargoDependency.smithyProtocolTestHelpers(codegenContext.runtimeConfig)) addDependency(SerdeJson) addDependency(Tokio) addDependency(FuturesUtil) @@ -99,7 +98,7 @@ class IntegrationTestDependencies( private fun serviceSpecificCustomizations(): List = when (moduleName) { "transcribestreaming" -> listOf(TranscribeTestDependencies()) - "s3" -> listOf(S3TestDependencies()) + "s3" -> listOf(S3TestDependencies(codegenContext)) else -> emptyList() } } @@ -113,7 +112,7 @@ class TranscribeTestDependencies : LibRsCustomization() { } } -class S3TestDependencies : LibRsCustomization() { +class S3TestDependencies(private val codegenContext: ClientCodegenContext) : LibRsCustomization() { override fun section(section: LibRsSection): Writable = writable { addDependency(AsyncStd) @@ -124,5 +123,12 @@ class S3TestDependencies : LibRsCustomization() { addDependency(TempFile) addDependency(TracingAppender) addDependency(TracingTest) + + // TODO(enableNewSmithyRuntime): These additional dependencies may not be needed anymore when removing this flag + // depending on if the sra-test is kept around or not. + if (codegenContext.settings.codegenConfig.enableNewSmithyRuntime) { + addDependency(CargoDependency.smithyRuntime(codegenContext.runtimeConfig).toDevDependency()) + addDependency(CargoDependency.smithyRuntimeApi(codegenContext.runtimeConfig).toDevDependency()) + } } } diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/auth.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/auth.rs deleted file mode 100644 index 632f7957a..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/auth.rs +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_runtime_api::client::orchestrator::BoxError; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct GetObjectAuthOrc {} - -impl GetObjectAuthOrc { - pub fn new() -> Self { - Self {} - } -} - -impl RuntimePlugin for GetObjectAuthOrc { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put an auth orchestrator in the bag - Ok(()) - } -} diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/conn.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/conn.rs deleted file mode 100644 index 84350fd7c..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/conn.rs +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_client::conns::Https; -use aws_smithy_client::hyper_ext::Adapter; -use aws_smithy_http::body::SdkBody; -use aws_smithy_runtime_api::client::orchestrator::{ - BoxError, BoxFallibleFut, Connection, HttpRequest, -}; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct HyperConnection { - _adapter: Adapter, -} - -impl RuntimePlugin for HyperConnection { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put a connection in the bag - Ok(()) - } -} - -impl HyperConnection { - pub fn new() -> Self { - Self { - _adapter: Adapter::builder().build(aws_smithy_client::conns::https()), - } - } -} - -impl Connection for HyperConnection { - fn call( - &self, - _req: &mut HttpRequest, - _cfg: &ConfigBag, - ) -> BoxFallibleFut> { - todo!("hyper's connector wants to take ownership of req"); - // self.adapter.call(req) - } -} diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/de.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/de.rs deleted file mode 100644 index 522889f2e..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/de.rs +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_runtime_api::client::interceptors::context::OutputOrError; -use aws_smithy_runtime_api::client::orchestrator::{BoxError, HttpResponse, ResponseDeserializer}; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct GetObjectResponseDeserializer {} - -impl GetObjectResponseDeserializer { - pub fn new() -> Self { - Self {} - } -} - -impl RuntimePlugin for GetObjectResponseDeserializer { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put a deserializer in the bag - Ok(()) - } -} - -impl ResponseDeserializer for GetObjectResponseDeserializer { - fn deserialize_streaming(&self, _response: &mut HttpResponse) -> Option { - todo!() - } - - fn deserialize_nonstreaming(&self, _response: &HttpResponse) -> OutputOrError { - todo!() - } -} diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/endpoints.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/endpoints.rs deleted file mode 100644 index 0eb3ac40a..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/endpoints.rs +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_http::event_stream::BoxError; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct GetObjectEndpointOrc {} - -impl GetObjectEndpointOrc { - pub fn new() -> Self { - Self {} - } -} - -impl RuntimePlugin for GetObjectEndpointOrc { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put an endpoint orchestrator in the bag - Ok(()) - } -} diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/main.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/main.rs deleted file mode 100644 index 7309a4caa..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/main.rs +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -mod auth; -mod conn; -mod de; -mod endpoints; -mod interceptors; -mod retry; -mod ser; - -use aws_sdk_s3::operation::get_object::{GetObjectError, GetObjectInput, GetObjectOutput}; -use aws_sdk_s3::types::ChecksumMode; -use aws_smithy_runtime::client::orchestrator::invoke; -use aws_smithy_runtime_api::client::interceptors::Interceptors; -use aws_smithy_runtime_api::client::orchestrator::{BoxError, HttpRequest, HttpResponse}; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugins; -use aws_smithy_runtime_api::config_bag::ConfigBag; -use aws_smithy_runtime_api::type_erasure::TypedBox; - -#[tokio::main] -async fn main() -> Result<(), BoxError> { - tracing_subscriber::fmt::init(); - - // Create the config we'll need to send the request + the request itself - let sdk_config = aws_config::load_from_env().await; - let _service_config = aws_sdk_s3::Config::from(&sdk_config); - - let input = TypedBox::new( - GetObjectInput::builder() - .bucket("zhessler-test-bucket") - .key("1000-lines.txt") - .checksum_mode(ChecksumMode::Enabled) - .build()?, - ) - .erase(); - - let mut runtime_plugins = RuntimePlugins::new(); - - // TODO(smithy-orchestrator-codegen) Make it so these are added by default for S3 - runtime_plugins - .with_client_plugin(auth::GetObjectAuthOrc::new()) - .with_client_plugin(conn::HyperConnection::new()) - // TODO(smithy-orchestrator-codegen) Make it so these are added by default for this S3 operation - .with_operation_plugin(endpoints::GetObjectEndpointOrc::new()) - .with_operation_plugin(retry::GetObjectRetryStrategy::new()) - .with_operation_plugin(de::GetObjectResponseDeserializer::new()) - .with_operation_plugin(ser::GetObjectInputSerializer::new()); - - let mut cfg = ConfigBag::base(); - let mut interceptors: Interceptors = Interceptors::new(); - let output = TypedBox::::assume_from( - invoke(input, &mut interceptors, &runtime_plugins, &mut cfg) - .await - .map_err(|err| { - err.map_service_error(|err| { - TypedBox::::assume_from(err) - .expect("error is GetObjectError") - .unwrap() - }) - })?, - ) - .expect("output is GetObjectOutput") - .unwrap(); - - dbg!(output); - Ok(()) -} diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/retry.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/retry.rs deleted file mode 100644 index 2a731a5f3..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/retry.rs +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_runtime_api::client::interceptors::InterceptorContext; -use aws_smithy_runtime_api::client::orchestrator::{ - BoxError, HttpRequest, HttpResponse, RetryStrategy, -}; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct GetObjectRetryStrategy {} - -impl GetObjectRetryStrategy { - pub fn new() -> Self { - Self {} - } -} - -impl RuntimePlugin for GetObjectRetryStrategy { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put a retry strategy in the bag - Ok(()) - } -} - -impl RetryStrategy for GetObjectRetryStrategy { - fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<(), BoxError> { - todo!() - } - - fn should_attempt_retry( - &self, - _context: &InterceptorContext, - _cfg: &ConfigBag, - ) -> Result { - todo!() - } -} - -// retry_classifier: Arc::new( -// |res: Result<&SdkSuccess, &SdkError>| -> RetryKind { -// let classifier = AwsResponseRetryClassifier::new(); -// classifier.classify_retry(res) -// }, -// ), diff --git a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/ser.rs b/aws/sdk/integration-tests/aws-smithy-runtime-test/src/ser.rs deleted file mode 100644 index 5aed30268..000000000 --- a/aws/sdk/integration-tests/aws-smithy-runtime-test/src/ser.rs +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_smithy_http::event_stream::BoxError; -use aws_smithy_runtime_api::client::interceptors::context::Input; -use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, RequestSerializer}; -use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; -use aws_smithy_runtime_api::config_bag::ConfigBag; - -#[derive(Debug)] -pub struct GetObjectInputSerializer {} - -impl GetObjectInputSerializer { - pub fn new() -> Self { - Self {} - } -} - -impl RuntimePlugin for GetObjectInputSerializer { - fn configure(&self, _cfg: &mut ConfigBag) -> Result<(), BoxError> { - // TODO(orchestrator) put a serializer in the bag - Ok(()) - } -} - -impl RequestSerializer for GetObjectInputSerializer { - fn serialize_input(&self, _input: &Input, _cfg: &ConfigBag) -> Result { - todo!() - } -} diff --git a/aws/sra-test/build.gradle.kts b/aws/sra-test/build.gradle.kts index d7cdeaa69..7344916af 100644 --- a/aws/sra-test/build.gradle.kts +++ b/aws/sra-test/build.gradle.kts @@ -19,6 +19,9 @@ val properties = PropertyRetriever(rootProject, project) val pluginName = "rust-client-codegen" val workingDirUnderBuildDir = "smithyprojections/sdk-sra-test/" +val publisherToolPath = rootProject.projectDir.resolve("tools/ci-build/publisher") +val outputDir = buildDir.resolve("sdk") + configure { outputDirectory = file("$buildDir/$workingDirUnderBuildDir") } @@ -36,28 +39,28 @@ dependencies { implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") } -val allCodegenTests = listOf( - CodegenTest( +data class Service( + val serviceId: String, + val moduleName: String, + val imports: List, +) +val servicesToGenerate = listOf( + Service( "com.amazonaws.dynamodb#DynamoDB_20120810", "aws-sdk-dynamodb", - imports = listOf("../sdk/aws-models/dynamodb.json"), - extraConfig = """ - , - "codegen": { - "includeFluentClient": false, - "enableNewSmithyRuntime": true - }, - "customizationConfig": { - "awsSdk": { - "generateReadme": false - } - } - """, + listOf("../sdk/aws-models/dynamodb.json"), ), - CodegenTest( + Service( "com.amazonaws.s3#AmazonS3", "aws-sdk-s3", - imports = listOf("../sdk/aws-models/s3.json", "../sdk/aws-models/s3-tests.smithy"), + listOf("../sdk/aws-models/s3.json", "../sdk/aws-models/s3-tests.smithy"), + ), +) +val allCodegenTests = servicesToGenerate.map { + CodegenTest( + it.serviceId, + it.moduleName, + imports = it.imports, extraConfig = """ , "codegen": { @@ -70,8 +73,8 @@ val allCodegenTests = listOf( } } """, - ), -) + ) +} project.registerGenerateSmithyBuildTask(rootProject, pluginName, allCodegenTests) project.registerGenerateCargoWorkspaceTask(rootProject, pluginName, allCodegenTests, workingDirUnderBuildDir) @@ -86,3 +89,37 @@ project.registerCargoCommandsTasks(buildDir.resolve(workingDirUnderBuildDir), de tasks["test"].finalizedBy(cargoCommands(properties).map { it.toString }) tasks["clean"].doFirst { delete("smithy-build.json") } + +/** + * The aws/rust-runtime crates depend on local versions of the Smithy core runtime enabling local compilation. However, + * those paths need to be replaced in the final build. We should probably fix this with some symlinking. + */ +fun rewritePathDependency(line: String): String { + // some runtime crates are actually dependent on the generated bindings: + return line.replace("../sdk/build/aws-sdk/sdk/", "") + // others use relative dependencies:: + .replace("../../rust-runtime/", "") +} + +tasks.register("relocateServices") { + description = "relocate AWS services to their final destination" + doLast { + servicesToGenerate.forEach { service -> + logger.info("Relocating ${service.moduleName}...") + copy { + from("$buildDir/smithyprojections/sdk-sra-test/${service.moduleName}/rust-client-codegen") + into(outputDir.resolve(service.moduleName)) + } + copy { + from(projectDir.resolve("integration-tests/${service.moduleName}/tests")) + into(outputDir.resolve(service.moduleName).resolve("tests")) + } + } + } + dependsOn("smithyBuildJar") + inputs.dir("$buildDir/smithyprojections/sdk-sra-test/") + outputs.dir(outputDir) +} +tasks["assemble"].apply { + dependsOn("relocateServices") +} diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/.gitignore b/aws/sra-test/integration-tests/aws-sdk-s3/.gitignore new file mode 100644 index 000000000..5a44eef09 --- /dev/null +++ b/aws/sra-test/integration-tests/aws-sdk-s3/.gitignore @@ -0,0 +1 @@ +/Cargo.lock diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/Cargo.toml b/aws/sra-test/integration-tests/aws-sdk-s3/Cargo.toml new file mode 100644 index 000000000..d2a2cd324 --- /dev/null +++ b/aws/sra-test/integration-tests/aws-sdk-s3/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "aws-smithy-runtime-test" +version = "0.1.0" +edition = "2021" +publish = false +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +aws-credential-types = { path = "../../../rust-runtime/aws-credential-types", features = ["test-util"] } +aws-http = { path = "../../../rust-runtime/aws-http" } +aws-sigv4 = { path = "../../../rust-runtime/aws-sigv4" } +aws-types = { path = "../../../rust-runtime/aws-types" } +aws-sdk-s3 = { path = "../../build/sdk/aws-sdk-s3" } +aws-smithy-async = { path = "../../../../rust-runtime/aws-smithy-async", features = ["rt-tokio"] } +aws-smithy-client = { path = "../../../../rust-runtime/aws-smithy-client" } +aws-smithy-types = { path = "../../../../rust-runtime/aws-smithy-types" } +aws-smithy-http = { path = "../../../../rust-runtime/aws-smithy-http" } +aws-smithy-runtime = { path = "../../../../rust-runtime/aws-smithy-runtime" } +aws-smithy-runtime-api = { path = "../../../../rust-runtime/aws-smithy-runtime-api" } +tokio = { version = "1.23.1", features = ["macros", "test-util", "rt-multi-thread"] } +tracing = "0.1.37" +tracing-subscriber = { version = "0.3.15", features = ["env-filter", "json"] } +http = "0.2.3" +http-body = "0.4.5" diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs b/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs new file mode 100644 index 000000000..fd4c8649c --- /dev/null +++ b/aws/sra-test/integration-tests/aws-sdk-s3/tests/interceptors.rs @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +// type TxReq = http::Request; +// type TxRes = http::Response; +// +// pub struct SigV4SigningConfigInterceptor { +// pub signing_service: &'static str, +// pub signing_region: Option, +// } + +// // Mount the interceptors +// let mut interceptors = Interceptors::new(); +// let sig_v4_signing_config_interceptor = SigV4SigningConfigInterceptor { +// signing_region: service_config.region.clone(), +// signing_service: service_config.signing_service(), +// }; +// let credentials_cache_interceptor = CredentialsCacheInterceptor { +// shared_credentials_cache: service_config.credentials_cache.clone(), +// }; +// let checksum_interceptor = ChecksumInterceptor { +// checksum_mode: input.checksum_mode().cloned(), +// }; +// interceptors +// .with_interceptor(sig_v4_signing_config_interceptor) +// .with_interceptor(credentials_cache_interceptor) +// .with_interceptor(checksum_interceptor); + +// let token_bucket = Box::new(standard::TokenBucket::builder().max_tokens(500).build()); +// +// impl Interceptor for SigV4SigningConfigInterceptor { +// fn modify_before_signing( +// &mut self, +// context: &mut InterceptorContext, +// ) -> Result<(), InterceptorError> { +// let mut props = context.properties_mut(); +// +// let mut signing_config = OperationSigningConfig::default_config(); +// signing_config.signing_options.content_sha256_header = true; +// signing_config.signing_options.double_uri_encode = false; +// signing_config.signing_options.normalize_uri_path = false; +// props.insert(signing_config); +// props.insert(aws_types::SigningService::from_static(self.signing_service)); +// +// if let Some(signing_region) = self.signing_region.as_ref() { +// props.insert(aws_types::region::SigningRegion::from( +// signing_region.clone(), +// )); +// } +// +// Ok(()) +// } +// } +// +// pub struct CredentialsCacheInterceptor { +// pub shared_credentials_cache: SharedCredentialsCache, +// } +// +// impl Interceptor for CredentialsCacheInterceptor { +// fn modify_before_signing( +// &mut self, +// context: &mut InterceptorContext, +// ) -> Result<(), InterceptorError> { +// match self +// .shared_credentials_cache +// .as_ref() +// .provide_cached_credentials() +// .now_or_never() +// { +// Some(Ok(creds)) => { +// context.properties_mut().insert(creds); +// } +// // ignore the case where there is no credentials cache wired up +// Some(Err(CredentialsError::CredentialsNotLoaded { .. })) => { +// tracing::info!("credentials cache returned CredentialsNotLoaded, ignoring") +// } +// // if we get another error class, there is probably something actually wrong that the user will +// // want to know about +// Some(Err(other)) => return Err(InterceptorError::ModifyBeforeSigning(other.into())), +// None => unreachable!("fingers crossed that creds are always available"), +// } +// +// Ok(()) +// } +// } +// +// pub struct ChecksumInterceptor { +// pub checksum_mode: Option, +// } +// +// impl Interceptor for ChecksumInterceptor { +// fn modify_before_serialization( +// &mut self, +// context: &mut InterceptorContext, +// ) -> Result<(), InterceptorError> { +// let mut props = context.properties_mut(); +// props.insert(self.checksum_mode.clone()); +// +// Ok(()) +// } +// } diff --git a/aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs b/aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs new file mode 100644 index 000000000..caa1b5bba --- /dev/null +++ b/aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs @@ -0,0 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +mod interceptors; + +#[tokio::test] +async fn sra_test() { + tracing_subscriber::fmt::init(); + + // TODO(orchestrator-testing): Replace the connector with a fake request/response + let config = aws_sdk_s3::Config::builder().build(); + let client = aws_sdk_s3::Client::from_conf(config); + + let _ = dbg!( + client + .get_object() + .bucket("zhessler-test-bucket") + .key("1000-lines.txt") + .send_v2() + .await + ); +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index 1a7681006..d7ce6440e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -24,6 +24,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGener import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations @@ -47,6 +48,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha import software.amazon.smithy.rust.codegen.core.util.CommandFailed import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.isEventStream import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.runCommand @@ -185,13 +187,7 @@ class ClientCodegenVisitor( * - Operation structures */ override fun serviceShape(shape: ServiceShape) { - ServiceGenerator( - rustCrate, - protocolGenerator, - protocolGeneratorFactory.support(), - codegenContext, - codegenDecorator, - ).render() + ServiceGenerator(rustCrate, codegenContext, codegenDecorator).render() } override fun getDefault(shape: Shape?) { @@ -301,14 +297,34 @@ class ClientCodegenVisitor( } /** - * Generate errors for operation shapes + * Generate operations */ - override fun operationShape(shape: OperationShape) { - rustCrate.withModule(symbolProvider.moduleForOperationError(shape)) { + override fun operationShape(operationShape: OperationShape) { + rustCrate.useShapeWriter(operationShape) operationWriter@{ + rustCrate.useShapeWriter(operationShape.inputShape(codegenContext.model)) inputWriter@{ + // Render the operation shape & serializers input `input.rs` + protocolGenerator.renderOperation( + this@operationWriter, + this@inputWriter, + operationShape, + codegenDecorator.operationCustomizations(codegenContext, operationShape, listOf()), + ) + + // render protocol tests into `operation.rs` (note operationWriter vs. inputWriter) + ProtocolTestGenerator( + codegenContext, + protocolGeneratorFactory.support(), + operationShape, + this@operationWriter, + ).render() + } + } + + rustCrate.withModule(symbolProvider.moduleForOperationError(operationShape)) { OperationErrorGenerator( model, symbolProvider, - shape, + operationShape, codegenDecorator.errorCustomizations(codegenContext, emptyList()), ).render(this) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt new file mode 100644 index 000000000..fb0a7f6e1 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType + +/** + * Generates operation-level runtime plugins + */ +class OperationRuntimePluginGenerator( + codegenContext: ClientCodegenContext, +) { + private val codegenScope = codegenContext.runtimeConfig.let { rc -> + arrayOf( + "BoxError" to RuntimeType.smithyRuntimeApi(rc).resolve("client::runtime_plugin::BoxError"), + "ConfigBag" to RuntimeType.smithyRuntimeApi(rc).resolve("config_bag::ConfigBag"), + "ConfigBagAccessors" to RuntimeType.smithyRuntimeApi(rc).resolve("client::orchestrator::ConfigBagAccessors"), + "RuntimePlugin" to RuntimeType.smithyRuntimeApi(rc).resolve("client::runtime_plugin::RuntimePlugin"), + ) + } + + fun render(writer: RustWriter, operationStructName: String) { + writer.rustTemplate( + """ + impl #{RuntimePlugin} for $operationStructName { + fn configure(&self, cfg: &mut #{ConfigBag}) -> Result<(), #{BoxError}> { + use #{ConfigBagAccessors} as _; + cfg.set_request_serializer(${operationStructName}RequestSerializer); + cfg.set_response_deserializer(${operationStructName}ResponseDeserializer); + Ok(()) + } + } + """, + *codegenScope, + ) + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt index c495c623f..b28a329fc 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt @@ -11,12 +11,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfigGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ServiceErrorGenerator -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator -import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport -import software.amazon.smithy.rust.codegen.core.util.inputShape /** * ServiceGenerator @@ -26,8 +22,6 @@ import software.amazon.smithy.rust.codegen.core.util.inputShape */ class ServiceGenerator( private val rustCrate: RustCrate, - private val protocolGenerator: ClientProtocolGenerator, - private val protocolSupport: ProtocolSupport, private val clientCodegenContext: ClientCodegenContext, private val decorator: ClientCodegenDecorator, ) { @@ -39,23 +33,6 @@ class ServiceGenerator( */ fun render() { val operations = index.getContainedOperations(clientCodegenContext.serviceShape).sortedBy { it.id } - operations.map { operation -> - rustCrate.useShapeWriter(operation) operationWriter@{ - rustCrate.useShapeWriter(operation.inputShape(clientCodegenContext.model)) inputWriter@{ - // Render the operation shape & serializers input `input.rs` - protocolGenerator.renderOperation( - this@operationWriter, - this@inputWriter, - operation, - decorator.operationCustomizations(clientCodegenContext, operation, listOf()), - ) - - // render protocol tests into `operation.rs` (note operationWriter vs. inputWriter) - ProtocolTestGenerator(clientCodegenContext, protocolSupport, operation, this@operationWriter).render() - } - } - } - ServiceErrorGenerator( clientCodegenContext, operations, @@ -67,6 +44,10 @@ class ServiceGenerator( clientCodegenContext, extraCustomizations = decorator.configCustomizations(clientCodegenContext, listOf()), ).render(this) + + if (clientCodegenContext.settings.codegenConfig.enableNewSmithyRuntime) { + ServiceRuntimePluginGenerator(clientCodegenContext).render(this) + } } rustCrate.lib { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt new file mode 100644 index 000000000..952814c02 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType + +/** + * Generates the service-level runtime plugin + */ +class ServiceRuntimePluginGenerator( + codegenContext: ClientCodegenContext, +) { + private val codegenScope = codegenContext.runtimeConfig.let { rc -> + arrayOf( + "BoxError" to RuntimeType.smithyRuntimeApi(rc).resolve("client::runtime_plugin::BoxError"), + "ConfigBag" to RuntimeType.smithyRuntimeApi(rc).resolve("config_bag::ConfigBag"), + "RuntimePlugin" to RuntimeType.smithyRuntimeApi(rc).resolve("client::runtime_plugin::RuntimePlugin"), + ) + } + + fun render(writer: RustWriter) { + writer.rustTemplate( + """ + pub(crate) struct ServiceRuntimePlugin; + + impl ServiceRuntimePlugin { + pub fn new() -> Self { Self } + } + + impl #{RuntimePlugin} for ServiceRuntimePlugin { + fn configure(&self, _cfg: &mut #{ConfigBag}) -> Result<(), #{BoxError}> { + // TODO(RuntimePlugins): Add the AuthOptionResolver to the config bag + // TODO(RuntimePlugins): Add the EndpointResolver to the config bag + // TODO(RuntimePlugins): Add the IdentityResolver to the config bag + // TODO(RuntimePlugins): Add the Connection to the config bag + // TODO(RuntimePlugins): Add the HttpAuthSchemes to the config bag + // TODO(RuntimePlugins): Add the RetryStrategy to the config bag + // TODO(RuntimePlugins): Add the TraceProbe to the config bag + Ok(()) + } + } + """, + *codegenScope, + ) + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index b0950583d..b0f48bc01 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -316,6 +316,46 @@ class FluentClientGenerator( generics.toRustGenerics(), ), ) + if (codegenContext.settings.codegenConfig.enableNewSmithyRuntime) { + rustTemplate( + """ + // TODO(enableNewSmithyRuntime): Replace `send` with `send_v2` + /// Sends the request and returns the response. + /// + /// If an error occurs, an `SdkError` will be returned with additional details that + /// can be matched against. + /// + /// By default, any retryable failures will be retried twice. Retry behavior + /// is configurable with the [RetryConfig](aws_smithy_types::retry::RetryConfig), which can be + /// set when configuring the client. + pub async fn send_v2(self) -> std::result::Result<#{OperationOutput}, #{SdkError}<#{OperationError}, #{HttpResponse}>> { + let runtime_plugins = #{RuntimePlugins}::new() + .with_client_plugin(crate::config::ServiceRuntimePlugin::new()) + .with_operation_plugin(#{Operation}::new()); + let input = self.inner.build().map_err(#{SdkError}::construction_failure)?; + let input = #{TypedBox}::new(input).erase(); + let output = #{invoke}(input, &runtime_plugins) + .await + .map_err(|err| { + err.map_service_error(|err| { + #{TypedBox}::<#{OperationError}>::assume_from(err) + .expect("correct error type") + .unwrap() + }) + })?; + Ok(#{TypedBox}::<#{OperationOutput}>::assume_from(output).expect("correct output type").unwrap()) + } + """, + "HttpResponse" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::HttpResponse"), + "OperationError" to errorType, + "Operation" to symbolProvider.toSymbol(operation), + "OperationOutput" to outputType, + "RuntimePlugins" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::runtime_plugin::RuntimePlugins"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "TypedBox" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypedBox"), + "invoke" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::invoke"), + ) + } PaginatorGenerator.paginatorType(codegenContext, generics, operation, retryClassifier)?.also { paginatorType -> rustTemplate( """ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt index 1d3f38a88..148a09d4b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt @@ -125,11 +125,11 @@ class RequestBindingGenerator( "${label.content} = ${local(member)}" } val combinedArgs = listOf(formatString, *args.toTypedArray()) - writer.addImport(RuntimeType.stdFmt.resolve("Write").toSymbol(), null) writer.rustBlockTemplate( "fn uri_base(_input: &#{Input}, output: &mut String) -> std::result::Result<(), #{BuildError}>", *codegenScope, ) { + rust("use #T as _;", RuntimeType.stdFmt.resolve("Write")) httpTrait.uri.labels.map { label -> val member = inputShape.expectMember(label.content) serializeLabel(member, label, local(member)) @@ -168,7 +168,7 @@ class RequestBindingGenerator( "fn uri_query(_input: &#{Input}, mut output: &mut String) -> Result<(), #{BuildError}>", *codegenScope, ) { - write("let mut query = #T::new(&mut output);", RuntimeType.queryFormat(runtimeConfig, "Writer")) + write("let mut query = #T::new(output);", RuntimeType.queryFormat(runtimeConfig, "Writer")) literalParams.forEach { (k, v) -> // When `v` is an empty string, no value should be set. // this generates a query string like `?k=v&xyz` diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt index 52016f877..de1983472 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationRuntimePluginGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBoundProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive @@ -19,17 +20,21 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.util.inputShape open class ClientProtocolGenerator( - codegenContext: ClientCodegenContext, + private val codegenContext: ClientCodegenContext, private val protocol: Protocol, /** * Operations generate a `make_operation(&config)` method to build a `aws_smithy_http::Operation` that can be dispatched * This is the serializer side of request dispatch */ + // TODO(enableNewSmithyRuntime): Remove the `makeOperationGenerator` private val makeOperationGenerator: MakeOperationGenerator, + private val bodyGenerator: ProtocolPayloadGenerator, + // TODO(enableNewSmithyRuntime): Remove the `traitGenerator` private val traitGenerator: HttpBoundProtocolTraitImplGenerator, ) : ProtocolGenerator(codegenContext, protocol) { /** @@ -42,6 +47,7 @@ open class ClientProtocolGenerator( */ fun renderOperation( operationWriter: RustWriter, + // TODO(enableNewSmithyRuntime): Remove the `inputWriter` since `make_operation` generation is going away inputWriter: RustWriter, operationShape: OperationShape, customizations: List, @@ -86,5 +92,14 @@ open class ClientProtocolGenerator( writeCustomizations(customizations, OperationSection.OperationImplBlock(customizations)) } traitGenerator.generateTraitImpls(operationWriter, operationShape, customizations) + + if (codegenContext.settings.codegenConfig.enableNewSmithyRuntime) { + OperationRuntimePluginGenerator(codegenContext).render(operationWriter, operationName) + + ResponseDeserializerGenerator(codegenContext, protocol) + .render(operationWriter, operationShape, customizations) + RequestSerializerGenerator(codegenContext, protocol, bodyGenerator) + .render(operationWriter, operationShape, customizations) + } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt index 0ac4f727c..5d206d47e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.letIf +// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` /** Generates the `make_operation` function on input structs */ open class MakeOperationGenerator( protected val codegenContext: CodegenContext, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt new file mode 100644 index 000000000..616325ee3 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolParserGenerator.kt @@ -0,0 +1,344 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.generators.http.ResponseBindingGenerator +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.assignment +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.errorMessageMember +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.core.util.outputShape + +class ProtocolParserGenerator( + private val codegenContext: ClientCodegenContext, + private val protocol: Protocol, +) { + private val model = codegenContext.model + private val httpBindingResolver = protocol.httpBindingResolver + private val protocolFunctions = ProtocolFunctions(codegenContext) + private val symbolProvider: RustSymbolProvider = codegenContext.symbolProvider + + private val codegenScope = arrayOf( + "http" to RuntimeType.Http, + "operation" to RuntimeType.operationModule(codegenContext.runtimeConfig), + "Bytes" to RuntimeType.Bytes, + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + // TODO(enableNewSmithyRuntime): Remove the `PropertyBag` below + "PropertyBag" to RuntimeType.smithyHttp(codegenContext.runtimeConfig).resolve("property_bag::PropertyBag"), + ) + + fun parseResponseFn(operationShape: OperationShape, customizations: List): RuntimeType { + val outputShape = operationShape.outputShape(model) + val outputSymbol = symbolProvider.toSymbol(outputShape) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) + return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response") { fnName -> + Attribute.AllowClippyUnnecessaryWraps.render(this) + rustBlockTemplate( + "pub fn $fnName(_response_status: u16, _response_headers: &#{http}::header::HeaderMap, _response_body: &[u8]) -> std::result::Result<#{O}, #{E}>", + *codegenScope, + "O" to outputSymbol, + "E" to errorSymbol, + ) { + withBlock("Ok({", "})") { + renderShapeParser( + operationShape, + outputShape, + httpBindingResolver.responseBindings(operationShape), + errorSymbol, + customizations, + ) + } + } + } + } + + fun parseErrorFn(operationShape: OperationShape, customizations: List): RuntimeType { + val outputShape = operationShape.outputShape(model) + val outputSymbol = symbolProvider.toSymbol(outputShape) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) + return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_error") { fnName -> + Attribute.AllowClippyUnnecessaryWraps.render(this) + rustBlockTemplate( + "pub fn $fnName(_response_status: u16, _response_headers: &#{http}::header::HeaderMap, _response_body: &[u8]) -> std::result::Result<#{O}, #{E}>", + *codegenScope, + "O" to outputSymbol, + "E" to errorSymbol, + ) { + Attribute.AllowUnusedMut.render(this) + rust( + "let mut generic_builder = #T(_response_status, _response_headers, _response_body).map_err(#T::unhandled)?;", + protocol.parseHttpErrorMetadata(operationShape), + errorSymbol, + ) + writeCustomizations( + customizations, + OperationSection.PopulateErrorMetadataExtras( + customizations, + "generic_builder", + "_response_status", + "_response_headers", + ), + ) + rust("let generic = generic_builder.build();") + if (operationShape.operationErrors(model).isNotEmpty()) { + rustTemplate( + """ + let error_code = match generic.code() { + Some(code) => code, + None => return Err(#{error_symbol}::unhandled(generic)) + }; + + let _error_message = generic.message().map(|msg|msg.to_owned()); + """, + "error_symbol" to errorSymbol, + ) + withBlock("Err(match error_code {", "})") { + val errors = operationShape.operationErrors(model) + errors.forEach { error -> + val errorShape = model.expectShape(error.id, software.amazon.smithy.model.shapes.StructureShape::class.java) + val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name + val errorCode = httpBindingResolver.errorCode(errorShape).dq() + withBlock( + "$errorCode => #1T::$variantName({", + "}),", + errorSymbol, + ) { + software.amazon.smithy.rust.codegen.core.rustlang.Attribute.AllowUnusedMut.render(this) + assignment("mut tmp") { + rustBlock("") { + renderShapeParser( + operationShape, + errorShape, + httpBindingResolver.errorResponseBindings(errorShape), + errorSymbol, + listOf( + object : OperationCustomization() { + override fun section(section: OperationSection): Writable = { + if (section is OperationSection.MutateOutput) { + rust("let output = output.meta(generic);") + } + } + }, + ), + ) + } + } + if (errorShape.errorMessageMember() != null) { + rust( + """ + if tmp.message.is_none() { + tmp.message = _error_message; + } + """, + ) + } + rust("tmp") + } + } + rust("_ => #T::generic(generic)", errorSymbol) + } + } else { + rust("Err(#T::generic(generic))", errorSymbol) + } + } + } + } + + fun parseStreamingResponseFn( + operationShape: OperationShape, + // TODO(enableNewSmithyRuntime): Remove the `includeProperties` flag as if it were always set to `false` + includeProperties: Boolean, + customizations: List, + ): RuntimeType { + val outputShape = operationShape.outputShape(model) + val outputSymbol = symbolProvider.toSymbol(outputShape) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) + val fnNameSuffix = if (includeProperties) "http_response_with_props" else "http_response" + return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = fnNameSuffix) { fnName -> + Attribute.AllowClippyUnnecessaryWraps.render(this) + val propertiesArg = if (includeProperties) { + Attribute.AllowUnusedVariables.render(this) + ", properties: &#{PropertyBag}" + } else { + "" + } + rustBlockTemplate( + "pub fn $fnName(response: &mut #{http}::Response<#{SdkBody}>$propertiesArg) -> std::result::Result<#{O}, #{E}>", + *codegenScope, + "O" to outputSymbol, + "E" to errorSymbol, + ) { + rustTemplate( + """ + let mut _response_body = #{SdkBody}::taken(); + std::mem::swap(&mut _response_body, response.body_mut()); + let _response_body = &mut _response_body; + + let _response_status = response.status().as_u16(); + let _response_headers = response.headers(); + """, + *codegenScope, + ) + withBlock("Ok({", "})") { + renderShapeParser( + operationShape, + outputShape, + httpBindingResolver.responseBindings(operationShape), + errorSymbol, + customizations, + ) + } + } + } + } + + private fun RustWriter.renderShapeParser( + operationShape: OperationShape, + outputShape: StructureShape, + bindings: List, + errorSymbol: Symbol, + customizations: List, + ) { + val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape) + val structuredDataParser = protocol.structuredDataParser(operationShape) + Attribute.AllowUnusedMut.render(this) + rust("let mut output = #T::default();", symbolProvider.symbolForBuilder(outputShape)) + if (outputShape.id == operationShape.output.get()) { + structuredDataParser.operationParser(operationShape)?.also { parser -> + rust( + "output = #T(_response_body, output).map_err(#T::unhandled)?;", + parser, + errorSymbol, + ) + } + } else { + check(outputShape.hasTrait()) { "should only be called on outputs or errors $outputShape" } + structuredDataParser.errorParser(outputShape)?.also { parser -> + rust( + "output = #T(_response_body, output).map_err(#T::unhandled)?;", + parser, errorSymbol, + ) + } + } + for (binding in bindings) { + val member = binding.member + val parsedValue = renderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + if (parsedValue != null) { + withBlock("output = output.${member.setterName()}(", ");") { + parsedValue(this) + } + } + } + + val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { + ".map_err(${format(errorSymbol)}::unhandled)?" + } else { + "" + } + + writeCustomizations( + customizations, + OperationSection.MutateOutput(customizations, operationShape, "_response_headers"), + ) + + rust("output.build()$err") + } + + /** + * Generate a parser & a parsed value converter for each output member of `operationShape` + * + * Returns a map with key = memberName, value = parsedValue + */ + private fun renderBindingParser( + binding: HttpBindingDescriptor, + operationShape: OperationShape, + httpBindingGenerator: ResponseBindingGenerator, + structuredDataParser: StructuredDataParserGenerator, + ): Writable? { + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) + val member = binding.member + return when (binding.location) { + HttpLocation.HEADER -> writable { + val fnName = httpBindingGenerator.generateDeserializeHeaderFn(binding) + rust( + """ + #T(_response_headers) + .map_err(|_|#T::unhandled("Failed to parse ${member.memberName} from header `${binding.locationName}"))? + """, + fnName, errorSymbol, + ) + } + HttpLocation.DOCUMENT -> { + // document is handled separately + null + } + HttpLocation.PAYLOAD -> { + val payloadParser: RustWriter.(String) -> Unit = { body -> + rust("#T($body).map_err(#T::unhandled)", structuredDataParser.payloadParser(member), errorSymbol) + } + val deserializer = httpBindingGenerator.generateDeserializePayloadFn( + binding, + errorSymbol, + payloadParser = payloadParser, + ) + return if (binding.member.isStreaming(model)) { + writable { rust("Some(#T(_response_body)?)", deserializer) } + } else { + writable { rust("#T(_response_body)?", deserializer) } + } + } + HttpLocation.RESPONSE_CODE -> writable { + rust("Some(_response_status as _)") + } + HttpLocation.PREFIX_HEADERS -> { + val sym = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) + writable { + rustTemplate( + """ + #{deser}(_response_headers) + .map_err(|_| + #{err}::unhandled("Failed to parse ${member.memberName} from prefix header `${binding.locationName}") + )? + """, + "deser" to sym, "err" to errorSymbol, + ) + } + } + else -> { + UNREACHABLE("Unexpected binding location: ${binding.location}") + } + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt new file mode 100644 index 000000000..74cac8c60 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/RequestSerializerGenerator.kt @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol + +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.smithy.generators.http.RequestBindingGenerator +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.findStreamingMember +import software.amazon.smithy.rust.codegen.core.util.inputShape + +class RequestSerializerGenerator( + private val codegenContext: ClientCodegenContext, + private val protocol: Protocol, + private val bodyGenerator: ProtocolPayloadGenerator, +) { + private val httpBindingResolver = protocol.httpBindingResolver + private val symbolProvider = codegenContext.symbolProvider + private val codegenScope by lazy { + CargoDependency.smithyRuntimeApi(codegenContext.runtimeConfig).toType().let { runtimeApi -> + val interceptorContext = runtimeApi.resolve("client::interceptors::context") + val orchestrator = runtimeApi.resolve("client::orchestrator") + arrayOf( + "BoxError" to orchestrator.resolve("BoxError"), + "HttpRequest" to orchestrator.resolve("HttpRequest"), + "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, + "Input" to interceptorContext.resolve("Input"), + "RequestSerializer" to orchestrator.resolve("RequestSerializer"), + "SdkBody" to RuntimeType.sdkBody(codegenContext.runtimeConfig), + "TypedBox" to runtimeApi.resolve("type_erasure::TypedBox"), + "config" to ClientRustModule.Config, + "header_util" to RuntimeType.smithyHttp(codegenContext.runtimeConfig).resolve("header"), + "http" to RuntimeType.Http, + "operation" to RuntimeType.operationModule(codegenContext.runtimeConfig), + ) + } + } + + fun render(writer: RustWriter, operationShape: OperationShape, customizations: List) { + val inputShape = operationShape.inputShape(codegenContext.model) + val operationName = symbolProvider.toSymbol(operationShape).name + val inputSymbol = symbolProvider.toSymbol(inputShape) + writer.rustTemplate( + """ + ##[derive(Debug)] + struct ${operationName}RequestSerializer; + impl #{RequestSerializer} for ${operationName}RequestSerializer { + ##[allow(unused_mut, clippy::let_and_return, clippy::needless_borrow, clippy::useless_conversion)] + fn serialize_input(&self, input: #{Input}) -> Result<#{HttpRequest}, #{BoxError}> { + let input = #{TypedBox}::<#{ConcreteInput}>::assume_from(input).expect("correct type").unwrap(); + let mut request_builder = { + #{create_http_request} + }; + let body = #{generate_body}; + #{add_content_length} + Ok(request_builder.body(body).expect("valid request")) + } + } + """, + *codegenScope, + "ConcreteInput" to inputSymbol, + "create_http_request" to createHttpRequest(operationShape), + "generate_body" to writable { + val body = writable { bodyGenerator.generatePayload(this, "input", operationShape) } + val streamingMember = inputShape.findStreamingMember(codegenContext.model) + val isBlobStreaming = + streamingMember != null && codegenContext.model.expectShape(streamingMember.target) is BlobShape + if (isBlobStreaming) { + // Consume the `ByteStream` into its inner `SdkBody`. + rust("#T.into_inner()", body) + } else { + rustTemplate("#{SdkBody}::from(#{body})", *codegenScope, "body" to body) + } + }, + "add_content_length" to if (needsContentLength(operationShape)) { + writable { + rustTemplate( + """ + if let Some(content_length) = body.content_length() { + request_builder = #{header_util}::set_request_header_if_absent(request_builder, #{http}::header::CONTENT_LENGTH, content_length); + } + """, + *codegenScope, + ) + } + } else { + writable { } + }, + ) + } + + private fun needsContentLength(operationShape: OperationShape): Boolean { + return protocol.httpBindingResolver.requestBindings(operationShape) + .any { it.location == HttpLocation.DOCUMENT || it.location == HttpLocation.PAYLOAD } + } + + private fun createHttpRequest(operationShape: OperationShape): Writable = writable { + val httpBindingGenerator = RequestBindingGenerator( + codegenContext, + protocol, + operationShape, + ) + httpBindingGenerator.renderUpdateHttpBuilder(this) + val contentType = httpBindingResolver.requestContentType(operationShape) + + rust("let mut builder = update_http_builder(&input, #T::new())?;", RuntimeType.HttpRequestBuilder) + if (contentType != null) { + rustTemplate( + "builder = #{header_util}::set_request_header_if_absent(builder, #{http}::header::CONTENT_TYPE, ${contentType.dq()});", + *codegenScope, + ) + } + for (header in protocol.additionalRequestHeaders(operationShape)) { + rustTemplate( + """ + builder = #{header_util}::set_request_header_if_absent( + builder, + #{http}::header::HeaderName::from_static(${header.first.dq()}), + ${header.second.dq()} + ); + """, + *codegenScope, + ) + } + rust("builder") + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt new file mode 100644 index 000000000..81c0494c1 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ResponseDeserializerGenerator.kt @@ -0,0 +1,169 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions +import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.core.util.outputShape + +class ResponseDeserializerGenerator( + codegenContext: ClientCodegenContext, + protocol: Protocol, +) { + private val symbolProvider = codegenContext.symbolProvider + private val model = codegenContext.model + private val runtimeConfig = codegenContext.runtimeConfig + private val httpBindingResolver = protocol.httpBindingResolver + private val parserGenerator = ProtocolParserGenerator(codegenContext, protocol) + + private val codegenScope by lazy { + val interceptorContext = + CargoDependency.smithyRuntimeApi(runtimeConfig).toType().resolve("client::interceptors::context") + val orchestrator = + CargoDependency.smithyRuntimeApi(runtimeConfig).toType().resolve("client::orchestrator") + arrayOf( + "Error" to interceptorContext.resolve("Error"), + "HttpResponse" to orchestrator.resolve("HttpResponse"), + "Instrument" to CargoDependency.Tracing.toType().resolve("Instrument"), + "Output" to interceptorContext.resolve("Output"), + "OutputOrError" to interceptorContext.resolve("OutputOrError"), + "ResponseDeserializer" to orchestrator.resolve("ResponseDeserializer"), + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "TypedBox" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypedBox"), + "debug_span" to RuntimeType.Tracing.resolve("debug_span"), + "type_erase_result" to typeEraseResult(), + ) + } + + fun render(writer: RustWriter, operationShape: OperationShape, customizations: List) { + val outputSymbol = symbolProvider.toSymbol(operationShape.outputShape(model)) + val operationName = symbolProvider.toSymbol(operationShape).name + val streaming = operationShape.outputShape(model).hasStreamingMember(model) + + writer.rustTemplate( + """ + ##[derive(Debug)] + struct ${operationName}ResponseDeserializer; + impl #{ResponseDeserializer} for ${operationName}ResponseDeserializer { + #{deserialize_streaming} + + fn deserialize_nonstreaming(&self, response: &#{HttpResponse}) -> #{OutputOrError} { + #{deserialize_nonstreaming} + } + } + """, + *codegenScope, + "O" to outputSymbol, + "E" to symbolProvider.symbolForOperationError(operationShape), + "deserialize_streaming" to writable { + if (streaming) { + deserializeStreaming(operationShape, customizations) + } + }, + "deserialize_nonstreaming" to writable { + when (streaming) { + true -> deserializeStreamingError(operationShape, customizations) + else -> deserializeNonStreaming(operationShape, customizations) + } + }, + ) + } + + private fun RustWriter.deserializeStreaming( + operationShape: OperationShape, + customizations: List, + ) { + val successCode = httpBindingResolver.httpTrait(operationShape).code + rustTemplate( + """ + fn deserialize_streaming(&self, response: &mut #{HttpResponse}) -> Option<#{OutputOrError}> { + #{BeforeParseResponse} + + // If this is an error, defer to the non-streaming parser + if !response.status().is_success() && response.status().as_u16() != $successCode { + return None; + } + Some(#{type_erase_result}(#{parse_streaming_response}(response))) + } + """, + *codegenScope, + "parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, false, customizations), + "BeforeParseResponse" to writable { + writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) + }, + ) + } + + private fun RustWriter.deserializeStreamingError( + operationShape: OperationShape, + customizations: List, + ) { + rustTemplate( + """ + // For streaming operations, we only hit this case if its an error + let body = response.body().bytes().expect("body loaded"); + #{type_erase_result}(#{parse_error}(response.status().as_u16(), response.headers(), body)) + """, + *codegenScope, + "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), + ) + } + + private fun RustWriter.deserializeNonStreaming( + operationShape: OperationShape, + customizations: List, + ) { + val successCode = httpBindingResolver.httpTrait(operationShape).code + rustTemplate( + """ + let (success, status) = (response.status().is_success(), response.status().as_u16()); + let headers = response.headers(); + let body = response.body().bytes().expect("body loaded"); + #{BeforeParseResponse} + let parse_result = if !success && status != $successCode { + #{parse_error}(status, headers, body) + } else { + #{parse_response}(status, headers, body) + }; + #{type_erase_result}(parse_result) + """, + *codegenScope, + "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), + "parse_response" to parserGenerator.parseResponseFn(operationShape, customizations), + "BeforeParseResponse" to writable { + writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) + }, + ) + } + + private fun typeEraseResult(): RuntimeType = ProtocolFunctions.crossOperationFn("type_erase_result") { fnName -> + rustTemplate( + """ + pub(crate) fn $fnName(result: Result) -> Result<#{Output}, #{Error}> + where + O: Send + Sync + 'static, + E: Send + Sync + 'static, + { + result.map(|output| #{TypedBox}::new(output).erase()) + .map_err(|error| #{TypedBox}::new(error).erase()) + } + """, + *codegenScope, + ) + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 76fe6eefa..45ea10824 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -7,69 +7,57 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.generators.http.ResponseBindingGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolParserGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.assignment import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName -import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor +import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator -import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator -import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors -import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE -import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.errorMessageMember import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember -import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.isStreaming import software.amazon.smithy.rust.codegen.core.util.outputShape +// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` class HttpBoundProtocolGenerator( codegenContext: ClientCodegenContext, protocol: Protocol, + bodyGenerator: ProtocolPayloadGenerator = HttpBoundProtocolPayloadGenerator(codegenContext, protocol), ) : ClientProtocolGenerator( codegenContext, protocol, MakeOperationGenerator( codegenContext, protocol, - HttpBoundProtocolPayloadGenerator(codegenContext, protocol), + bodyGenerator, public = true, includeDefaultPayloadHeaders = true, ), + bodyGenerator, HttpBoundProtocolTraitImplGenerator(codegenContext, protocol), ) +// TODO(enableNewSmithyRuntime): Delete this class when cleaning up `enableNewSmithyRuntime` open class HttpBoundProtocolTraitImplGenerator( - private val codegenContext: ClientCodegenContext, - private val protocol: Protocol, + codegenContext: ClientCodegenContext, + protocol: Protocol, ) { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model private val runtimeConfig = codegenContext.runtimeConfig private val httpBindingResolver = protocol.httpBindingResolver private val protocolFunctions = ProtocolFunctions(codegenContext) + private val parserGenerator = ProtocolParserGenerator(codegenContext, protocol) private val codegenScope = arrayOf( "ParseStrict" to RuntimeType.parseStrictResponse(runtimeConfig), @@ -79,25 +67,6 @@ open class HttpBoundProtocolTraitImplGenerator( "Bytes" to RuntimeType.Bytes, "SdkBody" to RuntimeType.sdkBody(runtimeConfig), ) - private val orchestratorCodegenScope by lazy { - val interceptorContext = - CargoDependency.smithyRuntimeApi(runtimeConfig).toType().resolve("client::interceptors::context") - val orchestrator = - CargoDependency.smithyRuntimeApi(runtimeConfig).toType().resolve("client::orchestrator") - arrayOf( - "Error" to interceptorContext.resolve("Error"), - "HttpResponse" to orchestrator.resolve("HttpResponse"), - "Instrument" to CargoDependency.Tracing.toType().resolve("Instrument"), - "Output" to interceptorContext.resolve("Output"), - "OutputOrError" to interceptorContext.resolve("OutputOrError"), - "ResponseDeserializer" to orchestrator.resolve("ResponseDeserializer"), - "SdkBody" to RuntimeType.sdkBody(runtimeConfig), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - "TypedBox" to CargoDependency.smithyRuntimeApi(runtimeConfig).toType().resolve("type_erasure::TypedBox"), - "debug_span" to RuntimeType.Tracing.resolve("debug_span"), - "type_erase_result" to typeEraseResult(), - ) - } open fun generateTraitImpls( operationWriter: RustWriter, @@ -116,130 +85,8 @@ open class HttpBoundProtocolTraitImplGenerator( } else { operationWriter.renderNonStreamingTraits(operationName, outputSymbol, operationShape, customizations) } - - if (codegenContext.settings.codegenConfig.enableNewSmithyRuntime) { - operationWriter.renderRuntimeTraits(operationName, outputSymbol, operationShape, customizations, streaming) - } - } - - private fun typeEraseResult(): RuntimeType = ProtocolFunctions.crossOperationFn("type_erase_result") { fnName -> - rustTemplate( - """ - pub(crate) fn $fnName(result: Result) -> Result<#{Output}, #{Error}> - where - O: Send + Sync + 'static, - E: Send + Sync + 'static, - { - result.map(|output| #{TypedBox}::new(output).erase()) - .map_err(|error| #{TypedBox}::new(error).erase()) - } - """, - *orchestratorCodegenScope, - ) - } - - private fun RustWriter.renderRuntimeTraits( - operationName: String?, - outputSymbol: Symbol, - operationShape: OperationShape, - customizations: List, - streaming: Boolean, - ) { - rustTemplate( - """ - impl #{ResponseDeserializer} for $operationName { - #{deserialize_streaming} - - fn deserialize_nonstreaming(&self, response: &#{HttpResponse}) -> #{OutputOrError} { - #{deserialize_nonstreaming} - } - } - """, - *orchestratorCodegenScope, - "O" to outputSymbol, - "E" to symbolProvider.symbolForOperationError(operationShape), - "deserialize_streaming" to writable { - if (streaming) { - deserializeStreaming(operationShape, customizations) - } - }, - "deserialize_nonstreaming" to writable { - when (streaming) { - true -> deserializeStreamingError(operationShape, customizations) - else -> deserializeNonStreaming(operationShape, customizations) - } - }, - ) - } - - private fun RustWriter.deserializeStreaming( - operationShape: OperationShape, - customizations: List, - ) { - val successCode = httpBindingResolver.httpTrait(operationShape).code - rustTemplate( - """ - fn deserialize_streaming(&self, response: &mut #{HttpResponse}) -> Option<#{OutputOrError}> { - #{BeforeParseResponse} - - // If this is an error, defer to the non-streaming parser - if !response.status().is_success() && response.status().as_u16() != $successCode { - return None; - } - Some(#{type_erase_result}(#{parse_streaming_response}(response))) - } - """, - *orchestratorCodegenScope, - "parse_streaming_response" to parseStreamingResponse(operationShape, customizations), - "BeforeParseResponse" to writable { - writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) - }, - ) } - private fun RustWriter.deserializeStreamingError( - operationShape: OperationShape, - customizations: List, - ) { - rustTemplate( - """ - // For streaming operations, we only hit this case if its an error - let body = response.body().bytes().expect("body loaded"); - #{type_erase_result}(#{parse_error}(response.status().as_u16(), response.headers(), body)) - """, - *orchestratorCodegenScope, - "parse_error" to parseError(operationShape, customizations), - ) - } - - private fun RustWriter.deserializeNonStreaming( - operationShape: OperationShape, - customizations: List, - ) { - val successCode = httpBindingResolver.httpTrait(operationShape).code - rustTemplate( - """ - let (success, status) = (response.status().is_success(), response.status().as_u16()); - let headers = response.headers(); - let body = response.body().bytes().expect("body loaded"); - #{BeforeParseResponse} - let parse_result = if !success && status != $successCode { - #{parse_error}(status, headers, body) - } else { - #{parse_response}(status, headers, body) - }; - #{type_erase_result}(parse_result) - """, - *orchestratorCodegenScope, - "parse_error" to parseError(operationShape, customizations), - "parse_response" to parseResponse(operationShape, customizations), - "BeforeParseResponse" to writable { - writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) - }, - ) - } - - // TODO(enableNewSmithyRuntime): Delete this when cleaning up `enableNewSmithyRuntime` private fun RustWriter.renderNonStreamingTraits( operationName: String?, outputSymbol: Symbol, @@ -250,8 +97,8 @@ open class HttpBoundProtocolTraitImplGenerator( val localScope = arrayOf( "O" to outputSymbol, "E" to symbolProvider.symbolForOperationError(operationShape), - "parse_error" to parseError(operationShape, customizations), - "parse_response" to parseResponse(operationShape, customizations), + "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), + "parse_response" to parserGenerator.parseResponseFn(operationShape, customizations), "BeforeParseResponse" to writable { writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) }, @@ -277,7 +124,6 @@ open class HttpBoundProtocolTraitImplGenerator( ) } - // TODO(enableNewSmithyRuntime): Delete this when cleaning up `enableNewSmithyRuntime` private fun RustWriter.renderStreamingTraits( operationName: String, outputSymbol: Symbol, @@ -305,8 +151,8 @@ open class HttpBoundProtocolTraitImplGenerator( """, "O" to outputSymbol, "E" to symbolProvider.symbolForOperationError(operationShape), - "parse_streaming_response" to parseStreamingResponseNoRt(operationShape, customizations), - "parse_error" to parseError(operationShape, customizations), + "parse_streaming_response" to parseStreamingResponse(operationShape, customizations), + "parse_error" to parserGenerator.parseErrorFn(operationShape, customizations), "BeforeParseResponse" to writable { writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) }, @@ -314,139 +160,11 @@ open class HttpBoundProtocolTraitImplGenerator( ) } - private fun parseError(operationShape: OperationShape, customizations: List): RuntimeType { - val outputShape = operationShape.outputShape(model) - val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = symbolProvider.symbolForOperationError(operationShape) - return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_error") { fnName -> - Attribute.AllowClippyUnnecessaryWraps.render(this) - rustBlockTemplate( - "pub fn $fnName(_response_status: u16, _response_headers: &#{http}::header::HeaderMap, _response_body: &[u8]) -> std::result::Result<#{O}, #{E}>", - *codegenScope, - "O" to outputSymbol, - "E" to errorSymbol, - ) { - Attribute.AllowUnusedMut.render(this) - rust( - "let mut generic_builder = #T(_response_status, _response_headers, _response_body).map_err(#T::unhandled)?;", - protocol.parseHttpErrorMetadata(operationShape), - errorSymbol, - ) - writeCustomizations( - customizations, - OperationSection.PopulateErrorMetadataExtras( - customizations, - "generic_builder", - "_response_status", - "_response_headers", - ), - ) - rust("let generic = generic_builder.build();") - if (operationShape.operationErrors(model).isNotEmpty()) { - rustTemplate( - """ - let error_code = match generic.code() { - Some(code) => code, - None => return Err(#{error_symbol}::unhandled(generic)) - }; - - let _error_message = generic.message().map(|msg|msg.to_owned()); - """, - "error_symbol" to errorSymbol, - ) - withBlock("Err(match error_code {", "})") { - val errors = operationShape.operationErrors(model) - errors.forEach { error -> - val errorShape = model.expectShape(error.id, StructureShape::class.java) - val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name - val errorCode = httpBindingResolver.errorCode(errorShape).dq() - withBlock( - "$errorCode => #1T::$variantName({", - "}),", - errorSymbol, - ) { - Attribute.AllowUnusedMut.render(this) - assignment("mut tmp") { - rustBlock("") { - renderShapeParser( - operationShape, - errorShape, - httpBindingResolver.errorResponseBindings(errorShape), - errorSymbol, - listOf(object : OperationCustomization() { - override fun section(section: OperationSection): Writable = writable { - if (section is OperationSection.MutateOutput) { - rust("let output = output.meta(generic);") - } - } - }, - ), - ) - } - } - if (errorShape.errorMessageMember() != null) { - rust( - """ - if tmp.message.is_none() { - tmp.message = _error_message; - } - """, - ) - } - rust("tmp") - } - } - rust("_ => #T::generic(generic)", errorSymbol) - } - } else { - rust("Err(#T::generic(generic))", errorSymbol) - } - } - } - } - private fun parseStreamingResponse(operationShape: OperationShape, customizations: List): RuntimeType { val outputShape = operationShape.outputShape(model) val outputSymbol = symbolProvider.toSymbol(outputShape) val errorSymbol = symbolProvider.symbolForOperationError(operationShape) - return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response") { fnName -> - Attribute.AllowClippyUnnecessaryWraps.render(this) - rustBlockTemplate( - "pub fn $fnName(response: &mut #{http}::Response<#{SdkBody}>) -> std::result::Result<#{O}, #{E}>", - *codegenScope, - "O" to outputSymbol, - "E" to errorSymbol, - ) { - rustTemplate( - """ - let mut _response_body = #{SdkBody}::taken(); - std::mem::swap(&mut _response_body, response.body_mut()); - let _response_body = &mut _response_body; - - let _response_status = response.status().as_u16(); - let _response_headers = response.headers(); - """, - *codegenScope, - ) - withBlock("Ok({", "})") { - renderShapeParser( - operationShape, - outputShape, - httpBindingResolver.responseBindings(operationShape), - errorSymbol, - customizations, - ) - } - } - } - } - - // TODO(enableNewSmithyRuntime): Delete this when cleaning up `enableNewSmithyRuntime` - private fun parseStreamingResponseNoRt(operationShape: OperationShape, customizations: List): RuntimeType { - val outputShape = operationShape.outputShape(model) - val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = symbolProvider.symbolForOperationError(operationShape) - return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response_") { fnName -> + return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "op_response") { fnName -> Attribute.AllowClippyUnnecessaryWraps.render(this) rustBlockTemplate( "pub fn $fnName(op_response: &mut #{operation}::Response) -> std::result::Result<#{O}, #{E}>", @@ -459,167 +177,10 @@ open class HttpBoundProtocolTraitImplGenerator( rust("let (response, properties) = op_response.parts_mut();") rustTemplate( """ - let mut _response_body = #{SdkBody}::taken(); - std::mem::swap(&mut _response_body, response.body_mut()); - let _response_body = &mut _response_body; - - let _response_status = response.status().as_u16(); - let _response_headers = response.headers(); - """, - *codegenScope, - ) - withBlock("Ok({", "})") { - renderShapeParser( - operationShape, - outputShape, - httpBindingResolver.responseBindings(operationShape), - errorSymbol, - customizations, - ) - } - } - } - } - - private fun parseResponse(operationShape: OperationShape, customizations: List): RuntimeType { - val outputShape = operationShape.outputShape(model) - val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = symbolProvider.symbolForOperationError(operationShape) - return protocolFunctions.deserializeFn(operationShape, fnNameSuffix = "http_response") { fnName -> - Attribute.AllowClippyUnnecessaryWraps.render(this) - rustBlockTemplate( - "pub fn $fnName(_response_status: u16, _response_headers: &#{http}::header::HeaderMap, _response_body: &[u8]) -> std::result::Result<#{O}, #{E}>", - *codegenScope, - "O" to outputSymbol, - "E" to errorSymbol, - ) { - withBlock("Ok({", "})") { - renderShapeParser( - operationShape, - outputShape, - httpBindingResolver.responseBindings(operationShape), - errorSymbol, - customizations, - ) - } - } - } - } - - private fun RustWriter.renderShapeParser( - operationShape: OperationShape, - outputShape: StructureShape, - bindings: List, - errorSymbol: Symbol, - customizations: List, - ) { - val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape) - val structuredDataParser = protocol.structuredDataParser(operationShape) - Attribute.AllowUnusedMut.render(this) - rust("let mut output = #T::default();", symbolProvider.symbolForBuilder(outputShape)) - if (outputShape.id == operationShape.output.get()) { - structuredDataParser.operationParser(operationShape)?.also { parser -> - rust( - "output = #T(_response_body, output).map_err(#T::unhandled)?;", - parser, - errorSymbol, - ) - } - } else { - check(outputShape.hasTrait()) { "should only be called on outputs or errors $outputShape" } - structuredDataParser.errorParser(outputShape)?.also { parser -> - rust( - "output = #T(_response_body, output).map_err(#T::unhandled)?;", - parser, errorSymbol, - ) - } - } - for (binding in bindings) { - val member = binding.member - val parsedValue = renderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) - if (parsedValue != null) { - withBlock("output = output.${member.setterName()}(", ");") { - parsedValue(this) - } - } - } - - val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { - ".map_err(${format(errorSymbol)}::unhandled)?" - } else { - "" - } - - writeCustomizations( - customizations, - OperationSection.MutateOutput(customizations, operationShape, "_response_headers"), - ) - - rust("output.build()$err") - } - - /** - * Generate a parser & a parsed value converter for each output member of `operationShape` - * - * Returns a map with key = memberName, value = parsedValue - */ - private fun renderBindingParser( - binding: HttpBindingDescriptor, - operationShape: OperationShape, - httpBindingGenerator: ResponseBindingGenerator, - structuredDataParser: StructuredDataParserGenerator, - ): Writable? { - val errorSymbol = symbolProvider.symbolForOperationError(operationShape) - val member = binding.member - return when (binding.location) { - HttpLocation.HEADER -> writable { - val fnName = httpBindingGenerator.generateDeserializeHeaderFn(binding) - rust( - """ - #T(_response_headers) - .map_err(|_|#T::unhandled("Failed to parse ${member.memberName} from header `${binding.locationName}"))? + #{parse_streaming_response}(response, &properties) """, - fnName, errorSymbol, - ) - } - HttpLocation.DOCUMENT -> { - // document is handled separately - null - } - HttpLocation.PAYLOAD -> { - val payloadParser: RustWriter.(String) -> Unit = { body -> - rust("#T($body).map_err(#T::unhandled)", structuredDataParser.payloadParser(member), errorSymbol) - } - val deserializer = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorSymbol, - payloadParser = payloadParser, + "parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, true, customizations), ) - return if (binding.member.isStreaming(model)) { - writable { rust("Some(#T(_response_body)?)", deserializer) } - } else { - writable { rust("#T(_response_body)?", deserializer) } - } - } - HttpLocation.RESPONSE_CODE -> writable { - rust("Some(_response_status as _)") - } - HttpLocation.PREFIX_HEADERS -> { - val sym = httpBindingGenerator.generateDeserializePrefixHeaderFn(binding) - writable { - rustTemplate( - """ - #{deser}(_response_headers) - .map_err(|_| - #{err}::unhandled("Failed to parse ${member.memberName} from prefix header `${binding.locationName}") - )? - """, - "deser" to sym, "err" to errorSymbol, - ) - } - } - else -> { - UNREACHABLE("Unexpected binding location: ${binding.location}") } } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index 060e636c3..d0d271231 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -101,6 +101,7 @@ private class TestProtocolGenerator( codegenContext, protocol, TestProtocolMakeOperationGenerator(codegenContext, protocol, body, httpRequestBuilder), + TestProtocolPayloadGenerator(body), TestProtocolTraitImplGenerator(codegenContext, correctResponse), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 428c76656..490aab7c2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -256,6 +256,8 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) fun smithyHttpTower(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttpTower(runtimeConfig).toType() fun smithyJson(runtimeConfig: RuntimeConfig) = CargoDependency.smithyJson(runtimeConfig).toType() fun smithyQuery(runtimeConfig: RuntimeConfig) = CargoDependency.smithyQuery(runtimeConfig).toType() + fun smithyRuntime(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntime(runtimeConfig).toType() + fun smithyRuntimeApi(runtimeConfig: RuntimeConfig) = CargoDependency.smithyRuntimeApi(runtimeConfig).toType() fun smithyTypes(runtimeConfig: RuntimeConfig) = CargoDependency.smithyTypes(runtimeConfig).toType() fun smithyXml(runtimeConfig: RuntimeConfig) = CargoDependency.smithyXml(runtimeConfig).toType() private fun smithyProtocolTest(runtimeConfig: RuntimeConfig) = diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs index fc1065e05..4a3cb7bfe 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/context.rs @@ -13,7 +13,7 @@ pub type OutputOrError = Result; /// A container for the data currently available to an interceptor. pub struct InterceptorContext { - input: Input, + input: Option, output_or_error: Option, request: Option, response: Option, @@ -24,21 +24,30 @@ pub struct InterceptorContext { impl InterceptorContext { pub fn new(input: Input) -> Self { Self { - input, + input: Some(input), output_or_error: None, request: None, response: None, } } - /// Retrieve the modeled request for the operation being invoked. - pub fn input(&self) -> &Input { - &self.input + /// Retrieve the input for the operation being invoked. + pub fn input(&self) -> Result<&Input, InterceptorError> { + self.input + .as_ref() + .ok_or_else(InterceptorError::invalid_input_access) + } + + /// Retrieve the input for the operation being invoked. + pub fn input_mut(&mut self) -> Result<&mut Input, InterceptorError> { + self.input + .as_mut() + .ok_or_else(InterceptorError::invalid_input_access) } - /// Retrieve the modeled request for the operation being invoked. - pub fn input_mut(&mut self) -> &mut Input { - &mut self.input + /// Takes ownership of the input. + pub fn take_input(&mut self) -> Option { + self.input.take() } /// Retrieve the transmittable request for the operation being invoked. @@ -80,7 +89,7 @@ impl InterceptorContext { pub fn output_or_error(&self) -> Result, InterceptorError> { self.output_or_error .as_ref() - .ok_or_else(InterceptorError::invalid_modeled_response_access) + .ok_or_else(InterceptorError::invalid_output_access) .map(|res| res.as_ref()) } @@ -90,10 +99,10 @@ impl InterceptorContext { pub fn output_or_error_mut(&mut self) -> Result<&mut Result, InterceptorError> { self.output_or_error .as_mut() - .ok_or_else(InterceptorError::invalid_modeled_response_access) + .ok_or_else(InterceptorError::invalid_output_access) } - // There is no set_modeled_request method because that can only be set once, during context construction + // There is no set_input method because that can only be set once, during context construction pub fn set_request(&mut self, request: Request) { if self.request.is_some() { @@ -125,7 +134,7 @@ impl InterceptorContext { pub fn into_parts( self, ) -> ( - Input, + Option, Option, Option, Option, diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs index 9fba28092..fc3002771 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors/error.rs @@ -189,24 +189,31 @@ impl InterceptorError { source: Some(source.into()), } } - /// Create a new error indicating that an interceptor tried to access the tx_request out of turn + /// Create a new error indicating that an interceptor tried to access the request out of turn pub fn invalid_request_access() -> Self { Self { kind: ErrorKind::InvalidRequestAccess, source: None, } } - /// Create a new error indicating that an interceptor tried to access the tx_response out of turn + /// Create a new error indicating that an interceptor tried to access the response out of turn pub fn invalid_response_access() -> Self { Self { kind: ErrorKind::InvalidResponseAccess, source: None, } } - /// Create a new error indicating that an interceptor tried to access the modeled_response out of turn - pub fn invalid_modeled_response_access() -> Self { + /// Create a new error indicating that an interceptor tried to access the input out of turn + pub fn invalid_input_access() -> Self { Self { - kind: ErrorKind::InvalidModeledResponseAccess, + kind: ErrorKind::InvalidInputAccess, + source: None, + } + } + /// Create a new error indicating that an interceptor tried to access the output out of turn + pub fn invalid_output_access() -> Self { + Self { + kind: ErrorKind::InvalidOutputAccess, source: None, } } @@ -257,8 +264,10 @@ enum ErrorKind { InvalidRequestAccess, /// An interceptor tried to access the response out of turn InvalidResponseAccess, - /// An interceptor tried to access the modeled_response out of turn - InvalidModeledResponseAccess, + /// An interceptor tried to access the input out of turn + InvalidInputAccess, + /// An interceptor tried to access the output out of turn + InvalidOutputAccess, } impl fmt::Display for InterceptorError { @@ -327,9 +336,13 @@ impl fmt::Display for InterceptorError { InvalidResponseAccess => { write!(f, "tried to access response before transmitting a request") } - InvalidModeledResponseAccess => write!( + InvalidInputAccess => write!( + f, + "tried to access the input before response deserialization" + ), + InvalidOutputAccess => write!( f, - "tried to access modeled_response before response deserialization" + "tried to access the output before response deserialization" ), } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index c70bcb8c1..e38b28485 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -22,11 +22,11 @@ pub type BoxError = Box; pub type BoxFallibleFut = Pin>>>; pub trait TraceProbe: Send + Sync + Debug { - fn dispatch_events(&self, cfg: &ConfigBag) -> BoxFallibleFut<()>; + fn dispatch_events(&self) -> BoxFallibleFut<()>; } pub trait RequestSerializer: Send + Sync + Debug { - fn serialize_input(&self, input: &Input, cfg: &ConfigBag) -> Result; + fn serialize_input(&self, input: Input) -> Result; } pub trait ResponseDeserializer: Send + Sync + Debug { @@ -161,11 +161,7 @@ pub trait HttpRequestSigner: Send + Sync + Debug { } pub trait EndpointResolver: Send + Sync + Debug { - fn resolve_and_apply_endpoint( - &self, - request: &mut HttpRequest, - cfg: &ConfigBag, - ) -> Result<(), BoxError>; + fn resolve_and_apply_endpoint(&self, request: &mut HttpRequest) -> Result<(), BoxError>; } pub trait ConfigBagAccessors { diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs index f29de78e0..8b262a453 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs @@ -5,7 +5,7 @@ use crate::config_bag::ConfigBag; -type BoxError = Box; +pub type BoxError = Box; pub trait RuntimePlugin { fn configure(&self, cfg: &mut ConfigBag) -> Result<(), BoxError>; @@ -32,17 +32,17 @@ impl RuntimePlugins { } pub fn with_client_plugin( - &mut self, + mut self, plugin: impl Into>, - ) -> &mut Self { + ) -> Self { self.client_plugins.push(plugin.into()); self } pub fn with_operation_plugin( - &mut self, + mut self, plugin: impl Into>, - ) -> &mut Self { + ) -> Self { self.operation_plugins.push(plugin.into()); self } @@ -79,7 +79,6 @@ mod tests { #[test] fn can_add_runtime_plugin_implementors_to_runtime_plugins() { - let mut rps = RuntimePlugins::new(); - rps.with_client_plugin(SomeStruct); + RuntimePlugins::new().with_client_plugin(SomeStruct); } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs index 72de2ccbb..50c852f0c 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/type_erasure.rs @@ -76,6 +76,33 @@ impl DerefMut for TypedBox { } } +#[derive(Debug)] +pub struct TypedRef<'a, T> { + inner: &'a TypeErasedBox, + _phantom: PhantomData, +} + +impl<'a, T: 'static> TypedRef<'a, T> { + pub fn assume_from(type_erased: &'a TypeErasedBox) -> Option> { + if type_erased.downcast_ref::().is_some() { + Some(TypedRef { + inner: type_erased, + _phantom: Default::default(), + }) + } else { + None + } + } +} + +impl<'a, T: 'static> Deref for TypedRef<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.inner.downcast_ref().expect("type checked") + } +} + /// A new-type around `Box` #[derive(Debug)] pub struct TypeErasedBox { diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index cdecbcc15..d97ca1cdf 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -22,10 +22,13 @@ pub(self) mod phase; pub async fn invoke( input: Input, - interceptors: &mut Interceptors, runtime_plugins: &RuntimePlugins, - cfg: &mut ConfigBag, ) -> Result> { + let mut cfg = ConfigBag::base(); + let cfg = &mut cfg; + let mut interceptors = Interceptors::new(); + let interceptors = &mut interceptors; + let context = Phase::construction(InterceptorContext::new(input)) // Client configuration .include(|_| runtime_plugins.apply_client_configuration(cfg))? @@ -39,7 +42,8 @@ pub async fn invoke( // Serialization .include_mut(|ctx| { let request_serializer = cfg.request_serializer(); - let request = request_serializer.serialize_input(ctx.input(), cfg)?; + let request = request_serializer + .serialize_input(ctx.take_input().expect("input set at this point"))?; ctx.set_request(request); Result::<(), BoxError>::Ok(()) })? @@ -83,7 +87,7 @@ pub async fn invoke( let handling_phase = Phase::response_handling(context) .include_mut(|ctx| interceptors.modify_before_completion(ctx, cfg))?; let trace_probe = cfg.trace_probe(); - trace_probe.dispatch_events(cfg); + trace_probe.dispatch_events(); break handling_phase.include(|ctx| interceptors.read_after_execution(ctx, cfg))?; }; @@ -105,7 +109,7 @@ async fn make_an_attempt( let request = ctx.request_mut().expect("request has been set"); let endpoint_resolver = cfg.endpoint_resolver(); - endpoint_resolver.resolve_and_apply_endpoint(request, cfg) + endpoint_resolver.resolve_and_apply_endpoint(request) })? .include_mut(|ctx| interceptors.modify_before_signing(ctx, cfg))? .include(|ctx| interceptors.read_before_signing(ctx, cfg))?; -- GitLab