diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/FluentClientGenerator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/FluentClientGenerator.kt index 484124e16ba1be6cf446e6200cfe877c683334b6..d9c52e1b91041911d7ce10e87942622a43b2c2f9 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/FluentClientGenerator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/FluentClientGenerator.kt @@ -79,6 +79,10 @@ class FluentClientGenerator(protocolConfig: ProtocolConfig) { Self { handle: std::sync::Arc::new(Handle { conf, client })} } + pub fn conf(&self) -> &crate::Config { + &self.handle.conf + } + """, "aws_hyper" to hyperDep.asType() ) @@ -115,7 +119,7 @@ class FluentClientGenerator(protocolConfig: ProtocolConfig) { } pub async fn send(self) -> Result<#{ok}, #{sdk_err}<#{operation_err}>> { - let op = self.inner.build(&self.handle.conf); + let op = self.inner.build(&self.handle.conf).map_err(|err|#{sdk_err}::ConstructionFailure(err.into()))?; self.handle.client.call(op).await } """, diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 222be6354725f6e194792a52b86832620d22a867..245f5b8041a9edcd034310de7a842c698c0d7c9b 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -164,8 +164,7 @@ fun generateCargoWorkspace(services: List): String { val examples = projectDir.resolve("examples").listFiles { file -> !file.name.startsWith(".") }?.toList() ?.map { "examples/${it.name}" }.orEmpty() - val modules = services.map(AwsService::module) + runtimeModules + awsModules + examples - ?.toList() + val modules = services.map(AwsService::module) + runtimeModules + awsModules + examples.toList() return """ [workspace] members = [ diff --git a/aws/sdk/examples/dynamo-movies/src/main.rs b/aws/sdk/examples/dynamo-movies/src/main.rs index 777753ecd6b6edf4772e1068c18903ba8553b716..d1983b2d5594e733363ba6b07cde855714ba5ade 100644 --- a/aws/sdk/examples/dynamo-movies/src/main.rs +++ b/aws/sdk/examples/dynamo-movies/src/main.rs @@ -6,15 +6,14 @@ use aws_http::AwsErrorRetryPolicy; use aws_hyper::{SdkError, SdkSuccess}; use dynamodb::error::DescribeTableError; -use dynamodb::input::{ - create_table_input, put_item_input, query_input, DescribeTableInput, ListTablesInput, - PutItemInput, QueryInput, -}; +use dynamodb::fluent::fluent_builders::Query; +use dynamodb::fluent::Client; +use dynamodb::input::DescribeTableInput; use dynamodb::model::{ AttributeDefinition, AttributeValue, KeySchemaElement, KeyType, ProvisionedThroughput, ScalarAttributeType, TableStatus, }; -use dynamodb::operation::{CreateTable, DescribeTable}; +use dynamodb::operation::DescribeTable; use dynamodb::output::DescribeTableOutput; use dynamodb::{Config, Region}; use serde_json::Value; @@ -34,13 +33,16 @@ use std::time::Duration; #[tokio::main] async fn main() { let table_name = "dynamo-movies-example"; - let client = aws_hyper::Client::https(); let conf = dynamodb::Config::builder() .region(Region::new("us-east-1")) .build(); + let conn = aws_hyper::conn::Standard::https(); + let client = dynamodb::fluent::Client::from_conf_conn(conf, conn); + let raw_client = aws_hyper::Client::https(); let table_exists = client - .call(ListTablesInput::builder().build(&conf)) + .list_tables() + .send() .await .expect("should succeed") .table_names @@ -49,14 +51,14 @@ async fn main() { .contains(&table_name.to_string()); if !table_exists { - client - .call(create_table(table_name).build(&conf)) + create_table(&client, table_name) + .send() .await .expect("failed to create table"); } - client - .call(wait_for_ready_table(table_name, &conf)) + raw_client + .call(wait_for_ready_table(table_name, client.conf())) .await .expect("table should become ready"); @@ -66,21 +68,24 @@ async fn main() { Value::Array(inner) => inner, data => panic!("data must be an array, got: {:?}", data), }; - for item in data { + for value in data { client - .call(add_item(table_name, item).build(&conf)) + .put_item() + .table_name(table_name) + .item(parse_item(value)) + .send() .await .expect("failed to insert item"); } - let films_2222 = client - .call(movies_in_year(table_name, 2222).build(&conf)) + let films_2222 = movies_in_year(&client, table_name, 2222) + .send() .await .expect("query should succeed"); // this isn't back to the future, there are no movies from 2022 assert_eq!(films_2222.count, 0); - let films_2013 = client - .call(movies_in_year(table_name, 2013).build(&conf)) + let films_2013 = movies_in_year(&client, table_name, 2013) + .send() .await .expect("query should succeed"); assert_eq!(films_2013.count, 2); @@ -99,8 +104,12 @@ async fn main() { ); } -fn create_table(table_name: &str) -> create_table_input::Builder { - CreateTable::builder() +fn create_table( + client: &Client, + table_name: &str, +) -> dynamodb::fluent::fluent_builders::CreateTable { + client + .create_table() .table_name(table_name) .key_schema(vec![ KeySchemaElement::builder() @@ -130,6 +139,13 @@ fn create_table(table_name: &str) -> create_table_input::Builder { ) } +fn parse_item(value: Value) -> HashMap { + match value_to_item(value) { + AttributeValue::M(map) => map, + other => panic!("can only insert top level values, got {:?}", other), + } +} + fn value_to_item(value: Value) -> AttributeValue { match value { Value::Null => AttributeValue::Null(true), @@ -143,23 +159,13 @@ fn value_to_item(value: Value) -> AttributeValue { } } -fn add_item(table_name: impl Into, item: Value) -> put_item_input::Builder { - let attribute_value = match value_to_item(item) { - AttributeValue::M(map) => map, - other => panic!("can only insert top level values, got {:?}", other), - }; - - PutItemInput::builder() - .table_name(table_name) - .item(attribute_value) -} - -fn movies_in_year(table_name: &str, year: u16) -> query_input::Builder { +fn movies_in_year(client: &Client, table_name: &str, year: u16) -> Query { let mut expr_attrib_names = HashMap::new(); expr_attrib_names.insert("#yr".to_string(), "year".to_string()); let mut expr_attrib_values = HashMap::new(); expr_attrib_values.insert(":yyyy".to_string(), AttributeValue::N(year.to_string())); - QueryInput::builder() + client + .query() .table_name(table_name) .key_condition_expression("#yr = :yyyy") .expression_attribute_names(expr_attrib_names) @@ -214,7 +220,8 @@ fn wait_for_ready_table( ) -> Operation> { let operation = DescribeTableInput::builder() .table_name(table_name) - .build(&conf); + .build(&conf) + .expect("valid input"); let waiting_policy = WaitForReadyTable { inner: operation.retry_policy().clone(), }; diff --git a/aws/sdk/examples/kms-helloworld/src/main.rs b/aws/sdk/examples/kms-helloworld/src/main.rs index 9d7de58ba3135c70c51dea5837fbcd996d62cbdd..810754b74b514f9474aa85db2849d7723283bf15 100644 --- a/aws/sdk/examples/kms-helloworld/src/main.rs +++ b/aws/sdk/examples/kms-helloworld/src/main.rs @@ -15,7 +15,7 @@ async fn main() { .build(); let client: StandardClient = aws_hyper::Client::https(); let data = client - .call(GenerateRandom::builder().number_of_bytes(64).build(&config)) + .call(GenerateRandom::builder().number_of_bytes(64).build(&config).expect("valid operation")) .await .expect("failed to generate random data"); println!("{:?}", data); diff --git a/aws/sdk/integration-tests/dynamodb/tests/movies.rs b/aws/sdk/integration-tests/dynamodb/tests/movies.rs index 92cbcfed083c6738e11164455b0254325ca30812..0061020b1561684dcb24ed0932a8c1ae5e42dbba 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/movies.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/movies.rs @@ -143,7 +143,7 @@ fn wait_for_ready_table( ) -> Operation> { let operation = DescribeTableInput::builder() .table_name(table_name) - .build(&conf); + .build(&conf).expect("valid operation"); let waiting_policy = WaitForReadyTable { inner: operation.retry_policy().clone(), }; @@ -181,7 +181,7 @@ async fn movies_it() { .credentials_provider(Credentials::from_keys("AKNOTREAL", "NOT_A_SECRET", None)) .build(); client - .call(create_table(table_name).build(&conf)) + .call(create_table(table_name).build(&conf).expect("valid request")) .await .expect("failed to create table"); @@ -200,19 +200,19 @@ async fn movies_it() { }; for item in data { client - .call(add_item(table_name, item.clone()).build(&conf)) + .call(add_item(table_name, item.clone()).build(&conf).expect("valid request")) .await .expect("failed to insert item"); } let films_2222 = client - .call(movies_in_year(table_name, 2222).build(&conf)) + .call(movies_in_year(table_name, 2222).build(&conf).expect("valid request")) .await .expect("query should succeed"); // this isn't back to the future, there are no movies from 2022 assert_eq!(films_2222.count, 0); let films_2013 = client - .call(movies_in_year(table_name, 2013).build(&conf)) + .call(movies_in_year(table_name, 2013).build(&conf).expect("valid request")) .await .expect("query should succeed"); assert_eq!(films_2013.count, 2); diff --git a/aws/sdk/integration-tests/kms/tests/integration.rs b/aws/sdk/integration-tests/kms/tests/integration.rs index b8226ba5bc62d64d60d3ef15b8a3da9ac76c314d..b86e01547aaf6581e5389ffc0ed47e6fef99c8ed 100644 --- a/aws/sdk/integration-tests/kms/tests/integration.rs +++ b/aws/sdk/integration-tests/kms/tests/integration.rs @@ -47,7 +47,7 @@ async fn generate_random() { .region(Region::new("us-east-1")) .credentials_provider(creds) .build(); - let mut op = GenerateRandom::builder().number_of_bytes(64).build(&conf); + let mut op = GenerateRandom::builder().number_of_bytes(64).build(&conf).expect("valid operation"); op.config_mut() .insert(UNIX_EPOCH + Duration::from_secs(1614952162)); op.config_mut().insert(AwsUserAgent::for_tests()); @@ -87,7 +87,7 @@ async fn generate_random_malformed_response() { .region(Region::new("us-east-1")) .credentials_provider(creds) .build(); - let op = GenerateRandom::builder().number_of_bytes(64).build(&conf); + let op = GenerateRandom::builder().number_of_bytes(64).build(&conf).expect("valid operation"); client.call(op).await.expect_err("response was malformed"); } @@ -130,7 +130,7 @@ async fn generate_random_keystore_not_found() { let mut op = GenerateRandom::builder() .number_of_bytes(64) .custom_key_store_id("does not exist") - .build(&conf); + .build(&conf).expect("valid operation"); op.config_mut() .insert(UNIX_EPOCH + Duration::from_secs(1614955644)); diff --git a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs index 29f83f6336f3898e631e903204e0261b400db55e..dd27ad4c69795a3efdd7b8c9ef04cfc468fab1dd 100644 --- a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs +++ b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs @@ -26,7 +26,7 @@ fn validate_sensitive_trait() { #[test] fn errors_are_retryable() { let conf = kms::Config::builder().build(); - let (_, parts) = CreateAlias::builder().build(&conf).into_request_response(); + let (_, parts) = CreateAlias::builder().build(&conf).expect("valid request").into_request_response(); let http_response = http::Response::builder() .status(400) .body(r#"{ "code": "LimitExceededException" }"#) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt index ced2bca97760faf7ccac9fc76858c91e5453acc7..63b7855f100f1bcb63f29968a81e3a13eb0e9b9e 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/BuilderGenerator.kt @@ -20,11 +20,11 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.stripOuter import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.smithy.Default +import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.defaultValue import software.amazon.smithy.rust.codegen.smithy.isOptional -import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.makeOptional import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq @@ -57,8 +57,15 @@ class ModelBuilderGenerator( } } } + + override fun RustWriter.missingRequiredField(field: String) { + val errMessage = "$field is required when building ${symbolProvider.toSymbol(shape).name}" + rust(errMessage.dq()) + } } +fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).member("BuildError") + class OperationInputBuilderGenerator( model: Model, private val symbolProvider: RustSymbolProvider, @@ -66,29 +73,41 @@ class OperationInputBuilderGenerator( private val serviceName: String, private val features: List, ) : BuilderGenerator(model, symbolProvider, shape.inputShape(model)) { + private val runtimeConfig = symbolProvider.config().runtimeConfig + + override fun RustWriter.missingRequiredField(field: String) { + val detailedMessage = "$field was not specified but it is required when building ${ + symbolProvider.toSymbol( + shape + ).name + }" + rust( + """#T::MissingField { field: ${field.dq()}, details: ${detailedMessage.dq()}}""", runtimeConfig.operationBuildError() + ) + } + override fun buildFn(implBlockWriter: RustWriter) { - val fallibleBuilder = StructureGenerator.fallibleBuilder(shape.inputShape(model), symbolProvider) val outputSymbol = symbolProvider.toSymbol(shape) - val operationT = RuntimeType.operation(symbolProvider.config().runtimeConfig) - val operationModule = RuntimeType.operationModule(symbolProvider.config().runtimeConfig) - val sdkBody = RuntimeType.sdkBody(symbolProvider.config().runtimeConfig) + val operationT = RuntimeType.operation(runtimeConfig) + val operationModule = RuntimeType.operationModule(runtimeConfig) + val sdkBody = RuntimeType.sdkBody(runtimeConfig) val retryType = features.mapNotNull { it.retryType() }.firstOrNull()?.let { implBlockWriter.format(it) } ?: "()" - val returnType = with(implBlockWriter) { - "${format(operationT)}<${format(outputSymbol)}, $retryType>".letIf(fallibleBuilder) { "Result<$it, String>" } - } + + val baseReturnType = with(implBlockWriter) { "${format(operationT)}<${format(outputSymbol)}, $retryType>" } + val returnType = "Result<$baseReturnType, ${implBlockWriter.format(runtimeConfig.operationBuildError())}>" implBlockWriter.docs("Consumes the builder and constructs an Operation<#D>", outputSymbol) // For codegen simplicity, allow `let x = ...; x` implBlockWriter.rust("##[allow(clippy::let_and_return)]") implBlockWriter.rustBlock("pub fn build(self, _config: &#T::Config) -> $returnType", RuntimeType.Config) { - conditionalBlock("Ok({", "})", conditional = fallibleBuilder) { + withBlock("Ok({", "})") { withBlock("let op = #T::new(", ");", outputSymbol) { coreBuilder(this) } rust( """ ##[allow(unused_mut)] - let mut request = #T::Request::new(op.build_http_request().map(#T::from)); + let mut request = #T::Request::new(op.build_http_request()?.map(#T::from)); """, operationModule, sdkBody ) @@ -181,6 +200,8 @@ abstract class BuilderGenerator( } } + abstract fun RustWriter.missingRequiredField(field: String) + abstract fun buildFn(implBlockWriter: RustWriter) /** @@ -199,13 +220,12 @@ abstract class BuilderGenerator( members.forEach { member -> val memberName = symbolProvider.toMemberName(member) val memberSymbol = symbolProvider.toSymbol(member) - val errorWhenMissing = "$memberName is required when building ${structureSymbol.name}" val default = memberSymbol.defaultValue() withBlock("$memberName: self.$memberName", ",") { // Write the modifier when { - !memberSymbol.isOptional() && default == Default.RustDefault -> write(".unwrap_or_default()") - !memberSymbol.isOptional() -> write(".ok_or(${errorWhenMissing.dq()})?") + !memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()") + !memberSymbol.isOptional() -> withBlock(".ok_or(", ")?") { missingRequiredField(memberName) } memberSymbol.isOptional() && default is Default.Custom -> { withBlock(".or_else(||Some(", "))") { default.render(this) } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt index 7048e032c04e788baf47b0c6c8bf3fc66871e470..5209be99a590bb73feb64a296dbab8dc49fbb9ac 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolGenerator.kt @@ -49,6 +49,7 @@ abstract class HttpProtocolGenerator( ) { private val symbolProvider = protocolConfig.symbolProvider private val model = protocolConfig.model + private val buildErrorT = protocolConfig.runtimeConfig.operationBuildError() fun renderOperation( operationWriter: RustWriter, inputWriter: RustWriter, @@ -99,9 +100,9 @@ abstract class HttpProtocolGenerator( builderGenerator.renderConvenienceMethod(this) rustBlock( - "pub fn build_http_request(&self) -> #T>", RuntimeType.Http("request::Request") + "pub fn build_http_request(&self) -> Result<#T>, #T>", RuntimeType.Http("request::Request"), buildErrorT ) { - write("#T::assemble(self.input.request_builder_base(), self.input.build_body())", inputSymbol) + write("Ok(#T::assemble(self.input.request_builder_base()?, self.input.build_body()))", inputSymbol) } fromResponseImpl(this, operationShape) @@ -126,8 +127,8 @@ abstract class HttpProtocolGenerator( protected fun httpBuilderFun(implBlockWriter: RustWriter, f: RustWriter.() -> Unit) { implBlockWriter.rustBlock( - "pub fn request_builder_base(&self) -> #T", - RuntimeType.HttpRequestBuilder + "pub fn request_builder_base(&self) -> Result<#T, #T>", + RuntimeType.HttpRequestBuilder, buildErrorT ) { f(this) } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt index 0fa6d7524db93a4885be3e4242a05abc78373015..0e3e82de14193cfe399f874769391d67dd08af85 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -355,15 +355,6 @@ class HttpProtocolTestGenerator( val AwsJson11 = "aws.protocoltests.json#JsonProtocol" val RestJson = "aws.protocoltests.restjson#RestJson" private val ExpectFail = setOf( - // Misc: - - // https://github.com/awslabs/smithy-rs/issues/35 - FailingTest( - RestJson, - "RestJsonHttpPrefixHeadersArePresent", - Action.Request - ), - // Document deserialization: FailingTest(AwsJson11, "PutAndGetInlineDocumentsInput", Action.Response), diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt index 5f39cf504fa6294f04e5cdbf9d0138845e9bc15b..bed0d5a923d1f90ee5ead0ef7ae2317818c8a213 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpTraitBindingGenerator.kt @@ -10,6 +10,8 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.pattern.SmithyPattern +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -17,9 +19,11 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.model.traits.TimestampFormatTrait +import software.amazon.smithy.rust.codegen.rustlang.Attribute import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq @@ -62,6 +66,7 @@ class HttpTraitBindingGenerator( // TODO: make defaultTimestampFormat configurable private val defaultTimestampFormat = TimestampFormatTrait.Format.EPOCH_SECONDS private val index = HttpBindingIndex.of(model) + private val buildError = runtimeConfig.operationBuildError() /** * Generates `update_http_builder` and all necessary dependency functions into the impl block provided by @@ -71,9 +76,11 @@ class HttpTraitBindingGenerator( uriBase(implBlockWriter) val hasHeaders = addHeaders(implBlockWriter) val hasQuery = uriQuery(implBlockWriter) + Attribute.Custom("allow(clippy::unnecessary_wraps)").render(implBlockWriter) implBlockWriter.rustBlock( - "fn update_http_builder(&self, builder: #1T) -> #1T", - RuntimeType.HttpRequestBuilder + "fn update_http_builder(&self, builder: #1T) -> Result<#1T, #2T>", + RuntimeType.HttpRequestBuilder, + buildError ) { write("let mut uri = String::new();") write("self.uri_base(&mut uri);") @@ -81,9 +88,9 @@ class HttpTraitBindingGenerator( write("self.uri_query(&mut uri);") } if (hasHeaders) { - write("let builder = self.add_headers(builder);") + write("let builder = self.add_headers(builder)?;") } - write("builder.method(${httpTrait.method.dq()}).uri(uri)") + write("Ok(builder.method(${httpTrait.method.dq()}).uri(uri))") } } @@ -95,47 +102,110 @@ class HttpTraitBindingGenerator( */ private fun addHeaders(writer: RustWriter): Boolean { val headers = index.getRequestBindings(shape, HttpBinding.Location.HEADER) - if (headers.isEmpty()) { + val prefixHeaders = index.getRequestBindings( + shape, + HttpBinding.Location.PREFIX_HEADERS + ) + val buildErrorT = runtimeConfig.operationBuildError() + + if (headers.isEmpty() && prefixHeaders.isEmpty()) { return false } writer.rustBlock( - "fn add_headers(&self, mut builder: #1T) -> #1T", - RuntimeType.HttpRequestBuilder + "fn add_headers(&self, mut builder: #1T) -> Result<#1T, #2T>", + RuntimeType.HttpRequestBuilder, + buildErrorT ) { - headers.forEach { httpBinding -> - val memberShape = httpBinding.member - val memberType = model.expectShape(memberShape.target) - val memberSymbol = symbolProvider.toSymbol(memberShape) - val memberName = symbolProvider.toMemberName(memberShape) - ifSet(memberType, memberSymbol, "&self.$memberName") { field -> - ListForEach(memberType, field) { innerField, targetId -> - val innerMemberType = model.expectShape(targetId) - val formatted = headerFmtFun(innerMemberType, memberShape, innerField) - val safeName = safeName("formatted") - write("let $safeName = $formatted;") - rustBlock("if !$safeName.is_empty()") { - write("builder = builder.header(${httpBinding.locationName.dq()}, $formatted);") - } - } - } + headers.forEach { httpBinding -> renderHeaders(httpBinding) } + prefixHeaders.forEach { httpBinding -> + renderPrefixHeaders(httpBinding) } - write("builder") + rust("Ok(builder)") } return true } + private fun RustWriter.renderPrefixHeaders(httpBinding: HttpBinding) { + val memberShape = httpBinding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + val target = when (memberType) { + is CollectionShape -> model.expectShape(memberType.member.target) + is MapShape -> model.expectShape(memberType.value.target) + else -> TODO("unexpected member for prefix headers: $memberType") + } + ifSet(memberType, memberSymbol, "&self.$memberName") { field -> + rustTemplate( + """ + for (k, v) in $field { + use std::str::FromStr; + let header_name = http::header::HeaderName::from_str(&format!("{}{}", ${httpBinding.locationName.dq()}, &k)).map_err(|err| { + #{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header name: {}", k, err)} + })?; + use std::convert::TryFrom; + let header_value = ${headerFmtFun(target, memberShape, "v")}; + let header_value = http::header::HeaderValue::try_from(header_value).map_err(|err| { + #{build_error}::InvalidField { + field: ${memberName.dq()}, + details: format!("`{}` cannot be used as a header value: {}", ${redactIfNecessary(memberShape, model,"v")}, err)} + })?; + builder = builder.header(header_name, header_value); + } + + """, + "build_error" to runtimeConfig.operationBuildError() + ) + } + } + + private fun RustWriter.renderHeaders(httpBinding: HttpBinding) { + val memberShape = httpBinding.member + val memberType = model.expectShape(memberShape.target) + val memberSymbol = symbolProvider.toSymbol(memberShape) + val memberName = symbolProvider.toMemberName(memberShape) + ifSet(memberType, memberSymbol, "&self.$memberName") { field -> + ListForEach(memberType, field) { innerField, targetId -> + val innerMemberType = model.expectShape(targetId) + val formatted = headerFmtFun(innerMemberType, memberShape, innerField) + val safeName = safeName("formatted") + write("let $safeName = $formatted;") + rustBlock("if !$safeName.is_empty()") { + rustTemplate( + """ + use std::convert::TryFrom; + let header_value = $safeName; + let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { + #{build_error}::InvalidField { field: ${memberName.dq()}, details: format!("`{}` cannot be used as a header value: {}", &${ + redactIfNecessary( + memberShape, + model, + "header_value" + ) + }, err)} + })?; + builder = builder.header(${httpBinding.locationName.dq()}, header_value); + """, + "build_error" to runtimeConfig.operationBuildError() + ) + } + } + } + } + /** * Format [member] in the when used as an HTTP header */ private fun headerFmtFun(target: Shape, member: MemberShape, targetName: String): String { return when { target.isStringShape -> { - val func = if (target.hasTrait(MediaTypeTrait::class.java)) { - writer.format(RuntimeType.Base64Encode(runtimeConfig)) + /*val func = */ if (target.hasTrait(MediaTypeTrait::class.java)) { + val func = writer.format(RuntimeType.Base64Encode(runtimeConfig)) + "$func(&${writer.useAs(target, targetName)})" } else { - writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string")) + writer.useAs(target, targetName) + // writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_string")) } - "$func(&${writer.useAs(target, targetName)})" } target.isTimestampShape -> { val timestampFormat = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index d2147df61854bd5ef0674abcd870df3d4d0c4b00..f95e30cc80b8ab5e6671b64acc4a87424f4f05a1 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -272,7 +272,8 @@ class Instantiator( } else { writer.write(".build()") } - if (StructureGenerator.fallibleBuilder(shape, symbolProvider)) { + // All operation builders are fallible + if (StructureGenerator.fallibleBuilder(shape, symbolProvider) || isSyntheticInput) { writer.write(".unwrap()") } } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt index ab0364376ec7508437fb03ed1995dafdb4589703..6e1464c4dd2bfff444477ad46bfe36b411f4d0d3 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/StructureGenerator.kt @@ -31,8 +31,13 @@ fun RustWriter.implBlock(structureShape: Shape, symbolProvider: SymbolProvider, } } -fun StructureShape.hasSensitiveMember(model: Model) = - this.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } +fun redactIfNecessary(member: MemberShape, model: Model, safeToPrint: String): String { + return if (member.getMemberTrait(model, SensitiveTrait::class.java).isPresent) { + "*** Sensitive Data Redacted ***".dq() + } else { + safeToPrint + } +} class StructureGenerator( val model: Model, @@ -89,11 +94,9 @@ class StructureGenerator( rust("""let mut formatter = f.debug_struct(${name.dq()});""") members.forEach { member -> val memberName = symbolProvider.toMemberName(member) - if (member.getMemberTrait(model, SensitiveTrait::class.java).isPresent) { - rust("""formatter.field(${memberName.dq()}, &"*** Sensitive Data Redacted ***");""") - } else { - rust("formatter.field(${memberName.dq()}, &self.$memberName);") - } + rust( + "formatter.field(${memberName.dq()}, &${redactIfNecessary(member, model, "self.$memberName")});", + ) } rust("formatter.finish()") } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt index 0960484cdfc090d7c93d4bb3540a09f7563bafda..9c0ed321d1ed089b2f49702b40f539cd2d96157e 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt @@ -169,10 +169,12 @@ class BasicAwsJsonGenerator( write("let builder = #T::new();", RuntimeType.HttpRequestBuilder) rust( """ - builder - .method("POST") - .header("Content-Type", "application/x-amz-json-${awsJsonVersion.value}") - .header("X-Amz-Target", "${protocolConfig.serviceShape.id.name}.${operationShape.id.name}") + Ok( + builder + .method("POST") + .header("Content-Type", "application/x-amz-json-${awsJsonVersion.value}") + .header("X-Amz-Target", "${protocolConfig.serviceShape.id.name}.${operationShape.id.name}") + ) """ ) } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt index 3e1ff46f5ee6ad12fef920178f70a4d77daa8482..350649c59efc909b912cd7c788a651ff850e1818 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/generators/HttpTraitBindingGeneratorTest.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.HttpTraitBindingGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.generators.uriFormatString import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.testutil.TestRuntimeConfig @@ -45,6 +46,11 @@ class HttpTraitBindingGeneratorTest { @mediaType("video/quicktime") string Video + map StringMap { + key: String, + value: String + } + structure PutObjectRequest { // Sent in the URI label named "key". @required @@ -84,6 +90,13 @@ class HttpTraitBindingGeneratorTest { // Sent in the body additional: String, + + @httpPrefixHeaders("X-Prefix-") + prefix: StringMap, + + @sensitive + @httpHeader("stringHeader") + stringHeader: String } """.asSmithyModel() private val model = OperationNormalizer(baseModel).transformModel( @@ -105,7 +118,11 @@ class HttpTraitBindingGeneratorTest { symbolProvider, TestRuntimeConfig, writer, operationShape, inputShape, httpTrait ).renderUpdateHttpBuilder(this) - rustBlock("pub fn request_builder_base(&self) -> #T", RuntimeType.HttpRequestBuilder) { + rustBlock( + "pub fn request_builder_base(&self) -> Result<#T, #T>", + RuntimeType.HttpRequestBuilder, + TestRuntimeConfig.operationBuildError() + ) { write("let builder = #T::new();", RuntimeType.HttpRequestBuilder) write("self.update_http_builder(builder)") } @@ -169,7 +186,10 @@ class HttpTraitBindingGeneratorTest { renderOperation(writer) writer.compileAndTest( """ + use std::collections::HashMap; let ts = smithy_types::Instant::from_epoch_seconds(10123125); + let mut prefix_header = HashMap::new(); + prefix_header.insert("k".to_string(), "😹".to_string()); let inp = PutObjectInput::builder() .bucket_name("buk") .date_header_list(vec![ts.clone()]) @@ -178,8 +198,9 @@ class HttpTraitBindingGeneratorTest { .extras(vec![0,1]) .some_value("qp") .media_type("base64encodethis") + .prefix(prefix_header) .build().unwrap(); - let http_request = inp.request_builder_base().body(()).unwrap(); + let http_request = inp.request_builder_base().expect("valid input").body(()).unwrap(); assert_eq!(http_request.uri(), "/buk/1970-04-28T03%3A58%3A45Z?paramName=qp&hello=0&hello=1"); assert_eq!(http_request.method(), "PUT"); let mut date_header = http_request.headers().get_all("X-Dates").iter(); @@ -191,6 +212,70 @@ class HttpTraitBindingGeneratorTest { let base64_header = http_request.headers().get_all("X-MediaType").iter().map(|hv|hv.to_str().unwrap()).collect::>(); assert_eq!(base64_header, vec!["YmFzZTY0ZW5jb2RldGhpcw=="]); + + let prefix_header = http_request.headers().get_all("X-Prefix-k").iter().map(|hv|std::str::from_utf8(hv.as_ref()).unwrap()).collect::>(); + assert_eq!(prefix_header, vec!["😹"]) + """ + ) + } + + @Test + fun `invalid header name produces an error`() { + val writer = RustWriter.forModule("input") + renderOperation(writer) + writer.compileAndTest( + """ + use std::collections::HashMap; + let ts = smithy_types::Instant::from_epoch_seconds(10123125); + let mut prefix_header = HashMap::new(); + prefix_header.insert("😹".to_string(), "😹".to_string()); + let inp = PutObjectInput::builder() + .bucket_name("buk") + .key(ts.clone()) + .prefix(prefix_header) + .build().unwrap(); + let err = inp.request_builder_base().expect_err("can't make a header out of a cat emoji"); + assert_eq!(format!("{}", err), "Invalid field in input: prefix (Details: `😹` cannot be used as a header name: invalid HTTP header name)"); + """ + ) + } + + @Test + fun `invalid prefix header value produces an error`() { + val writer = RustWriter.forModule("input") + renderOperation(writer) + writer.compileAndTest( + """ + use std::collections::HashMap; + let ts = smithy_types::Instant::from_epoch_seconds(10123125); + let mut prefix_header = HashMap::new(); + prefix_header.insert("valid-key".to_string(), "\n can't put a newline in a header value".to_string()); + let inp = PutObjectInput::builder() + .bucket_name("buk") + .key(ts.clone()) + .prefix(prefix_header) + .build().unwrap(); + let err = inp.request_builder_base().expect_err("can't make a header with a newline"); + assert_eq!(format!("{}", err), "Invalid field in input: prefix (Details: `\n can\'t put a newline in a header value` cannot be used as a header value: failed to parse header value)"); + """ + ) + } + + @Test + fun `invalid header value produces an error`() { + val writer = RustWriter.forModule("input") + renderOperation(writer) + writer.compileAndTest( + """ + let ts = smithy_types::Instant::from_epoch_seconds(10123125); + let inp = PutObjectInput::builder() + .bucket_name("buk") + .key(ts.clone()) + .string_header("\n is not valid") + .build().unwrap(); + let err = inp.request_builder_base().expect_err("can't make a header with a newline"); + // make sure we obey the sensitive trait + assert_eq!(format!("{}", err), "Invalid field in input: string_header (Details: `*** Sensitive Data Redacted ***` cannot be used as a header value: failed to parse header value)"); """ ) } diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt index 04dbdad14eeefc7cea86d1f06ce8a2b93e652f61..e1de32134656e971ad89268bc32041008cb8806a 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGeneratorTest.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.withBlock import software.amazon.smithy.rust.codegen.smithy.CodegenVisitor import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator @@ -143,8 +144,9 @@ class HttpProtocolTestGeneratorTest { inputShape: StructureShape ) { httpBuilderFun(implBlockWriter) { - write("#T::new()", RuntimeType.HttpRequestBuilder) - writeWithNoFormatting(httpRequestBuilder) + withBlock("Ok(#T::new()", ")", RuntimeType.HttpRequestBuilder) { + writeWithNoFormatting(httpRequestBuilder) + } } } } diff --git a/rust-runtime/smithy-http/Cargo.toml b/rust-runtime/smithy-http/Cargo.toml index 77e2cc76d2649a92ac755812efb8c01083669575..bc5f2659cf8941681f5c63399c3c2052a14b6472 100644 --- a/rust-runtime/smithy-http/Cargo.toml +++ b/rust-runtime/smithy-http/Cargo.toml @@ -9,6 +9,7 @@ smithy-types = { path = "../smithy-types" } bytes = "1" http-body = "0.4.0" http = "0.2.3" +thiserror = "1" [dev-dependencies] proptest = "0.10.1" diff --git a/rust-runtime/smithy-http/src/operation.rs b/rust-runtime/smithy-http/src/operation.rs index 0d841c1a25c0567dea6b6891a6bc1c908537f937..4690f49dbd8a8253030f422d84c89e72b2e01c1c 100644 --- a/rust-runtime/smithy-http/src/operation.rs +++ b/rust-runtime/smithy-http/src/operation.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use std::cell::{Ref, RefCell, RefMut}; use std::ops::DerefMut; use std::rc::Rc; +use thiserror::Error; #[derive(Clone)] pub struct Metadata { @@ -39,6 +40,26 @@ pub struct Parts { pub metadata: Option, } +/// An error occured attempting to build an `Operation` from an input +/// +/// These are almost always due to user error caused by limitations of specific fields due to +/// protocol serialization (eg. fields that can only be a subset ASCII because they are serialized +/// as the name of an HTTP header) +#[non_exhaustive] +#[derive(Debug, Error)] +pub enum BuildError { + #[error("Invalid field in input: {field} (Details: {details})")] + InvalidField { + field: &'static str, + details: String, + }, + #[error("{field} was missing. {details}")] + MissingField { + field: &'static str, + details: &'static str, + }, +} + pub struct Operation { request: Request, parts: Parts,