From a77c5ca2a911908dfc2a3dc793cbb193d038847e Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Wed, 5 Jan 2022 12:46:57 -0500 Subject: [PATCH] Add support for Paginators (#1006) * Add Support for Paginators This commit adds pagination support for both clients. It does this via the `FnStream` abstraction which combines a rendezvous-channel with a closure to produce a stream. When the stream is polled, the underlying closure is polled to advance the computation which then pushes data into the channel. Two paginators are generated: 1. `.paginate()` which produces a stream of `Result` 2. `.paginate().items()` which produces a stream of `Result` where items are flattened from the individual pages. This uses `TryFlatMap` to support conveniently generating a flat page iterator. In addition, docs are generated to point customers towards the paginators. * Add RFC, more tests * backout unrelated changes * Fix paginators for glacier * Fix clippy error * Add test for paginators that point to maps & fix apigateway paginator * Fix docs * remove extraeneous println * Fix DynamoDB examples * Disallow paginating operations with the idempotency token trait * Apply suggestions from code review misc typos / cleanups Co-authored-by: Zelda Hessler * CR feedback * Fix typo * Update changelog * Apply suggestions from code review Co-authored-by: John DiSanti * CR feedback round 1 * Rename paginate to into_paginator() * update pr bot comment Co-authored-by: Zelda Hessler Co-authored-by: John DiSanti --- .github/workflows/pull-request-bot.yml | 8 +- CHANGELOG.next.toml | 12 + .../rustsdk/AwsFluentClientDecorator.kt | 35 +-- .../rustsdk/IntegrationTestDependencies.kt | 1 + .../customize/auth/DisabledAuthDecorator.kt | 2 +- .../customize/glacier/AccountIdAutofill.kt | 2 +- aws/sdk/build.gradle.kts | 16 +- aws/sdk/examples/dynamodb/Cargo.toml | 1 + .../examples/dynamodb/src/bin/list-items.rs | 13 +- .../examples/dynamodb/src/bin/list-tables.rs | 12 +- aws/sdk/integration-tests/dynamodb/Cargo.toml | 2 + .../dynamodb/tests/paginators.rs | 171 +++++++++++ .../rust/codegen/rustlang/CargoDependency.kt | 2 + .../smithy/rust/codegen/rustlang/RustTypes.kt | 5 + .../rust/codegen/smithy/SymbolVisitor.kt | 8 + .../customize/OperationCustomization.kt | 3 - .../generators/FluentClientDecorator.kt | 65 ++-- .../generators/NestedAccessorGenerator.kt | 88 ++++++ .../smithy/generators/PaginatorGenerator.kt | 289 ++++++++++++++++++ .../smithy/generators/ServiceGenerator.kt | 3 +- .../smithy/protocols/InlineFunctionNamer.kt | 6 + .../generators/PaginatorGeneratorTest.kt | 75 +++++ design/src/rfcs/rfc0008_paginators.md | 197 ++++++++++++ rust-runtime/aws-smithy-async/Cargo.toml | 9 +- .../aws-smithy-async/src/future/fn_stream.rs | 265 ++++++++++++++++ .../aws-smithy-async/src/future/mod.rs | 2 + .../aws-smithy-async/src/future/rendezvous.rs | 132 ++++++++ .../aws-smithy-protocol-test/Cargo.toml | 1 + .../publisher/src/subcommand/fix_manifests.rs | 17 +- 29 files changed, 1358 insertions(+), 84 deletions(-) create mode 100644 aws/sdk/integration-tests/dynamodb/tests/paginators.rs create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/NestedAccessorGenerator.kt create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGenerator.kt create mode 100644 codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGeneratorTest.kt create mode 100644 design/src/rfcs/rfc0008_paginators.md create mode 100644 rust-runtime/aws-smithy-async/src/future/fn_stream.rs create mode 100644 rust-runtime/aws-smithy-async/src/future/rendezvous.rs diff --git a/.github/workflows/pull-request-bot.yml b/.github/workflows/pull-request-bot.yml index a2329b20d..b3785c3da 100644 --- a/.github/workflows/pull-request-bot.yml +++ b/.github/workflows/pull-request-bot.yml @@ -104,11 +104,11 @@ jobs: toolchain: ${{ env.rust_version }} default: true - name: Generate doc preview - # Only generate two of the smallest services since these get huge. One of these must be - # STS since aws-config depends on it. STS and Transcribe Streaming were chosen below to stay - # small while still representing most features. Combined, they are about 11 MB at time of writing. + # Only generate three of the smallest services since the doc build can be very large. One of these must be + # STS since aws-config depends on it. STS and Transcribe Streaming and DynamoDB (paginators/waiters) were chosen + # below to stay small while still representing most features. Combined, they are about ~20MB at time of writing. run: | - ./gradlew -Paws.services=+sts,+transcribestreaming :aws:sdk:assemble + ./gradlew -Paws.services=+sts,+transcribestreaming,+dynamodb :aws:sdk:assemble # Copy the Server runtime crate(s) in cp -r rust-runtime/aws-smithy-http-server aws/sdk/build/aws-sdk/sdk diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 12b762ced..7c30422a6 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -99,6 +99,18 @@ meta = { "breaking" = false, "tada" = false, "bug" = false } references = ["smithy-rs#1002", "aws-sdk-rust#352"] author = "rcoh" +[[smithy-rs]] +message = "Add support for paginators! Paginated APIs now include `.into_paginator()` and (when supported) `.into_paginator().items()` to enable paginating responses automatically. The paginator API should be considered in preview and is subject to change pending customer feedback." +meta = { "breaking" = false, "tada" = true, "bug" = false } +references = ["aws-sdk-rust#47", "smithy-rs#1006"] +author = "rcoh" + +[[aws-sdk-rust]] +message = "Add support for paginators! Paginated APIs now include `.into_paginator()` and (when supported) `.into_paginator().items()` to enable paginating responses automatically. The paginator API should be considered in preview and is subject to change pending customer feedback." +meta = { "breaking" = false, "tada" = true, "bug" = false } +references = ["aws-sdk-rust#47", "smithy-rs#1006"] +author = "rcoh" + [[aws-sdk-rust]] message = "Example for Config builder region function added" references = ["smithy-rs#670"] diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index 1abbfc4f6..0d5629287 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -10,8 +10,6 @@ import software.amazon.smithy.model.traits.TitleTrait import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.rustlang.Feature -import software.amazon.smithy.rust.codegen.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.rustlang.RustModule import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType @@ -52,27 +50,20 @@ class AwsFluentClientDecorator : RustCodegenDecorator { override fun extras(codegenContext: CodegenContext, rustCrate: RustCrate) { val types = Types(codegenContext.runtimeConfig) - val module = RustMetadata(public = true) - rustCrate.withModule( - RustModule( - "client", - module, - documentation = "Client and fluent builders for calling the service." + FluentClientGenerator( + codegenContext, + generics = ClientGenerics( + connectorDefault = types.dynConnector, + middlewareDefault = types.defaultMiddleware, + retryDefault = types.smithyClientRetry.member("Standard"), + client = types.awsSmithyClient + ), + customizations = listOf( + AwsPresignedFluentBuilderMethod(codegenContext.runtimeConfig), + AwsFluentClientDocs(codegenContext) ) - ) { writer -> - FluentClientGenerator( - codegenContext, - generics = ClientGenerics( - connectorDefault = types.dynConnector, - middlewareDefault = types.defaultMiddleware, - retryDefault = types.smithyClientRetry.member("Standard"), - client = types.awsSmithyClient - ), - customizations = listOf( - AwsPresignedFluentBuilderMethod(codegenContext.runtimeConfig), - AwsFluentClientDocs(codegenContext) - ) - ).render(writer) + ).render(rustCrate) + rustCrate.withModule(FluentClientGenerator.clientModule) { writer -> AwsFluentClientExtensions(types).render(writer) } val awsSmithyClient = "aws-smithy-client" diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index fd21cee43..ce5151fb6 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -60,6 +60,7 @@ class IntegrationTestDependencies( val smithyClient = CargoDependency.SmithyClient(runtimeConfig) .copy(features = setOf("test-util"), scope = DependencyScope.Dev) addDependency(smithyClient) + addDependency(CargoDependency.SmithyProtocolTestHelpers(runtimeConfig)) addDependency(SerdeJson) addDependency(Tokio) addDependency(FuturesUtil) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt index 547954489..6467ed874 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt @@ -38,7 +38,7 @@ class DisabledAuthDecorator : RustCodegenDecorator { val optionalOperations = optionalAuth[service.id]!! return ModelTransformer.create().mapShapes(model) { if (optionalOperations.contains(it.id) && it is OperationShape) { - it.toBuilder().addTrait(AuthTrait(listOf())).build() + it.toBuilder().addTrait(AuthTrait(setOf())).build() } else { it } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt index 83f7cac3c..89307c718 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt @@ -16,7 +16,7 @@ import software.amazon.smithy.rust.codegen.util.inputShape class AccountIdAutofill() : OperationCustomization() { override fun mutSelf(): Boolean = true - override fun consumesSelf(): Boolean = true + override fun consumesSelf(): Boolean = false override fun section(section: OperationSection): Writable { return when (section) { is OperationSection.MutateInput -> writable { diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index c33e21885..ac9f0b4ab 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -160,14 +160,16 @@ task("relocateServices") { task("relocateExamples") { description = "relocate the examples folder & rewrite path dependencies" doLast { - copy { - from(projectDir) - awsServices.examples.forEach { example -> - include("$example/**") + if (awsServices.examples.isNotEmpty()) { + copy { + from(projectDir) + awsServices.examples.forEach { example -> + include("$example/**") + } + into(outputDir) + exclude("**/target") + filter { line -> line.replace("build/aws-sdk/sdk/", "sdk/") } } - into(outputDir) - exclude("**/target") - filter { line -> line.replace("build/aws-sdk/sdk/", "sdk/") } } } inputs.dir(projectDir.resolve("examples")) diff --git a/aws/sdk/examples/dynamodb/Cargo.toml b/aws/sdk/examples/dynamodb/Cargo.toml index 1a973da34..7fa34ed3f 100644 --- a/aws/sdk/examples/dynamodb/Cargo.toml +++ b/aws/sdk/examples/dynamodb/Cargo.toml @@ -18,4 +18,5 @@ rand = "0.8.3" serde_json = "1" structopt = { version = "0.3", default-features = false } tokio = { version = "1", features = ["full"] } +tokio-stream = "0.1.8" tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/dynamodb/src/bin/list-items.rs b/aws/sdk/examples/dynamodb/src/bin/list-items.rs index 277e2025c..56c587faa 100644 --- a/aws/sdk/examples/dynamodb/src/bin/list-items.rs +++ b/aws/sdk/examples/dynamodb/src/bin/list-items.rs @@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain; use aws_sdk_dynamodb::{Client, Error, Region, PKG_VERSION}; +use tokio_stream::StreamExt; use structopt::StructOpt; @@ -58,11 +59,17 @@ async fn main() -> Result<(), Error> { println!(); } - let resp = client.scan().table_name(table).send().await?; + let items: Result, _> = client + .scan() + .table_name(table) + .into_paginator() + .items() + .send() + .collect() + .await; println!("Items in table:"); - - if let Some(item) = resp.items { + for item in items? { println!(" {:?}", item); } diff --git a/aws/sdk/examples/dynamodb/src/bin/list-tables.rs b/aws/sdk/examples/dynamodb/src/bin/list-tables.rs index 8123a979b..a9a3fc677 100644 --- a/aws/sdk/examples/dynamodb/src/bin/list-tables.rs +++ b/aws/sdk/examples/dynamodb/src/bin/list-tables.rs @@ -5,6 +5,7 @@ use aws_config::meta::region::RegionProviderChain; use aws_sdk_dynamodb::{Client, Error, Region, PKG_VERSION}; +use tokio_stream::StreamExt; use structopt::StructOpt; @@ -50,18 +51,15 @@ async fn main() -> Result<(), Error> { let client = Client::new(&shared_config); - let resp = client.list_tables().send().await?; + let paginator = client.list_tables().into_paginator().items().send(); + let table_names = paginator.collect::, _>>().await?; println!("Tables:"); - let names = resp.table_names.unwrap_or_default(); - let len = names.len(); - - for name in names { + for name in &table_names { println!(" {}", name); } - println!("Found {} tables", len); - + println!("Found {} tables", table_names.len()); Ok(()) } diff --git a/aws/sdk/integration-tests/dynamodb/Cargo.toml b/aws/sdk/integration-tests/dynamodb/Cargo.toml index 8084e8cc3..3edd8b984 100644 --- a/aws/sdk/integration-tests/dynamodb/Cargo.toml +++ b/aws/sdk/integration-tests/dynamodb/Cargo.toml @@ -13,6 +13,7 @@ aws-sdk-dynamodb = { path = "../../build/aws-sdk/sdk/dynamodb" } aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] } aws-smithy-http = { path = "../../build/aws-sdk/sdk/aws-smithy-http" } aws-smithy-types = { path = "../../build/aws-sdk/sdk/aws-smithy-types" } +aws-smithy-protocol-test = { path = "../../build/aws-sdk/sdk/aws-smithy-protocol-test" } aws-types = { path = "../../build/aws-sdk/sdk/aws-types" } bytes = "1" criterion = { version = "0.3.4" } @@ -21,6 +22,7 @@ http = "0.2.4" serde_json = "1" tokio = { version = "1", features = ["full", "test-util"]} tracing-subscriber = "0.2.16" +tokio-stream = "0.1.8" [[bench]] name = "deserialization_bench" diff --git a/aws/sdk/integration-tests/dynamodb/tests/paginators.rs b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs new file mode 100644 index 000000000..0c4462862 --- /dev/null +++ b/aws/sdk/integration-tests/dynamodb/tests/paginators.rs @@ -0,0 +1,171 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +use aws_sdk_dynamodb::model::AttributeValue; +use aws_sdk_dynamodb::{Client, Config}; +use aws_smithy_client::test_connection::{capture_request, TestConnection}; +use aws_smithy_http::body::SdkBody; +use aws_smithy_protocol_test::{assert_ok, validate_body, MediaType}; +use aws_types::region::Region; +use aws_types::Credentials; +use std::collections::HashMap; +use std::iter::FromIterator; +use tokio_stream::StreamExt; + +fn stub_config() -> Config { + Config::builder() + .region(Region::new("us-east-1")) + .credentials_provider(Credentials::new("akid", "secret", None, None, "test")) + .build() +} + +/// Validate that arguments are passed on to the paginator +#[tokio::test] +async fn paginators_pass_args() { + let (conn, request) = capture_request(None); + let client = Client::from_conf_conn(stub_config(), conn); + let mut paginator = client + .scan() + .table_name("test-table") + .into_paginator() + .page_size(32) + .send(); + let _ = paginator.next().await; + let request = request.expect_request(); + let body = request.body().bytes().expect("data is loaded"); + assert_ok(validate_body( + body, + r#"{"TableName":"test-table","Limit":32}"#, + MediaType::Json, + )); +} + +fn mk_request(body: &'static str) -> http::Request { + http::Request::builder() + .uri("https://dynamodb.us-east-1.amazonaws.com/") + .body(SdkBody::from(body)) + .unwrap() +} + +fn mk_response(body: &'static str) -> http::Response { + http::Response::builder().body(SdkBody::from(body)).unwrap() +} + +#[tokio::test] +async fn paginators_loop_until_completion() { + let conn = TestConnection::new(vec![ + ( + mk_request(r#"{"TableName":"test-table","Limit":32}"#), + mk_response( + r#"{ + "Count": 1, + "Items": [{ + "PostedBy": { + "S": "joe@example.com" + } + }], + "LastEvaluatedKey": { + "PostedBy": { "S": "joe@example.com" } + } + }"#, + ), + ), + ( + mk_request( + r#"{"TableName":"test-table","Limit":32,"ExclusiveStartKey":{"PostedBy":{"S":"joe@example.com"}}}"#, + ), + mk_response( + r#"{ + "Count": 1, + "Items": [{ + "PostedBy": { + "S": "jack@example.com" + } + }] + }"#, + ), + ), + ]); + let client = Client::from_conf_conn(stub_config(), conn.clone()); + let mut paginator = client + .scan() + .table_name("test-table") + .into_paginator() + .page_size(32) + .send(); + assert_eq!(conn.requests().len(), 0); + let first_page = paginator + .try_next() + .await + .expect("success") + .expect("page exists"); + assert_eq!( + first_page.items.unwrap_or_default(), + vec![HashMap::from_iter([( + "PostedBy".to_string(), + AttributeValue::S("joe@example.com".to_string()) + )])] + ); + assert_eq!(conn.requests().len(), 1); + let second_page = paginator + .try_next() + .await + .expect("success") + .expect("page exists"); + assert_eq!( + second_page.items.unwrap_or_default(), + vec![HashMap::from_iter([( + "PostedBy".to_string(), + AttributeValue::S("jack@example.com".to_string()) + )])] + ); + assert_eq!(conn.requests().len(), 2); + assert!( + paginator.next().await.is_none(), + "no more pages should exist" + ); + // we shouldn't make another request, we know we're at the end + assert_eq!(conn.requests().len(), 2); + conn.assert_requests_match(&[]); +} + +#[tokio::test] +async fn paginators_handle_errors() { + // LastEvaluatedKey is set but there is only one response in the test connection + let conn = TestConnection::new(vec![( + mk_request(r#"{"TableName":"test-table","Limit":32}"#), + mk_response( + r#"{ + "Count": 1, + "Items": [{ + "PostedBy": { + "S": "joe@example.com" + } + }], + "LastEvaluatedKey": { + "PostedBy": { "S": "joe@example.com" } + } + }"#, + ), + )]); + let client = Client::from_conf_conn(stub_config(), conn.clone()); + let mut rows = client + .scan() + .table_name("test-table") + .into_paginator() + .page_size(32) + .items() + .send(); + assert_eq!( + rows.try_next() + .await + .expect("no error") + .expect("not EOS") + .get("PostedBy"), + Some(&AttributeValue::S("joe@example.com".to_string())) + ); + rows.try_next().await.expect_err("failure"); + assert_eq!(rows.try_next().await.expect("ok"), None); +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt index ba0f42782..6c622f6cc 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt @@ -205,11 +205,13 @@ data class CargoDependency( val Ring: CargoDependency = CargoDependency("ring", CratesIo("0.16")) val TempFile: CargoDependency = CargoDependency("temp-file", CratesIo("0.1.6"), scope = DependencyScope.Dev) val TokioStream: CargoDependency = CargoDependency("tokio-stream", CratesIo("0.1.7")) + val AsyncStream: CargoDependency = CargoDependency("async-stream", CratesIo("0.3.2")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types") fun SmithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("client") + fun SmithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("async") fun SmithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("eventstream") fun SmithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http") fun SmithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-tower") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt index 67b5abdbb..cff762b4c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt @@ -97,6 +97,11 @@ sealed class RustType { data class Option(override val member: RustType) : RustType(), Container { override val name: kotlin.String = "Option" override val namespace = "std::option" + + /** Convert `Option` to `Option<&T>` **/ + fun referenced(lifetime: kotlin.String?): Option { + return Option(Reference(lifetime, this.member)) + } } data class Box(override val member: RustType) : RustType(), Container { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt index 495832fd7..3fdb863db 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/SymbolVisitor.kt @@ -98,6 +98,14 @@ fun Symbol.makeOptional(): Symbol { } } +fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol { + val newType = f(this.rustType()) + return Symbol.builder().rustType(newType) + .addReference(this) + .name(newType.name) + .build() +} + fun Symbol.makeRustBoxed(): Symbol { val symbol = this val rustType = RustType.Box(symbol.rustType()) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt index 8225fad72..8170405d4 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/customize/OperationCustomization.kt @@ -53,9 +53,6 @@ sealed class OperationSection(name: String) : Section(name) { abstract class OperationCustomization : NamedSectionGenerator() { open fun retryType(): RuntimeType? = null - // NOTE: mutSelf and consumes self must be set together due to a the fact that the fluent builder does not have any information - // about consumes / mut - /** * Does `make_operation` consume the self parameter? * diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt index aaa77b245..891975f57 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/FluentClientDecorator.kt @@ -56,19 +56,10 @@ class FluentClientDecorator : RustCodegenDecorator { return } - val module = RustMetadata(public = true) - rustCrate.withModule( - RustModule( - "client", - module, - documentation = "Client and fluent builders for calling the service." - ) - ) { writer -> - FluentClientGenerator( - codegenContext, - customizations = listOf(GenericFluentClient(codegenContext)) - ).render(writer) - } + FluentClientGenerator( + codegenContext, + customizations = listOf(GenericFluentClient(codegenContext)) + ).render(rustCrate) rustCrate.mergeFeature(Feature("rustls", default = true, listOf("aws-smithy-client/rustls"))) rustCrate.mergeFeature(Feature("native-tls", default = false, listOf("aws-smithy-client/native-tls"))) } @@ -273,7 +264,7 @@ class GenericFluentClient(codegenContext: CodegenContext) : FluentClientCustomiz } class FluentClientGenerator( - codegenContext: CodegenContext, + private val codegenContext: CodegenContext, private val generics: ClientGenerics = ClientGenerics( connectorDefault = null, middlewareDefault = null, @@ -285,6 +276,12 @@ class FluentClientGenerator( companion object { fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operationShape).name.toSnakeCase()) + + val clientModule = RustModule( + "client", + RustMetadata(public = true), + documentation = "Client and fluent builders for calling the service." + ) } private val serviceShape = codegenContext.serviceShape @@ -296,13 +293,19 @@ class FluentClientGenerator( private val runtimeConfig = codegenContext.runtimeConfig private val core = FluentClientCore(model) - fun render(writer: RustWriter) { + fun render(crate: RustCrate) { + crate.withModule(clientModule) { writer -> + renderFluentClient(writer) + } + } + + private fun renderFluentClient(writer: RustWriter) { writer.rustTemplate( """ ##[derive(Debug)] pub(crate) struct Handle#{generics_decl:W} { - client: #{client}::Client${generics.inst}, - conf: crate::Config, + pub(crate) client: #{client}::Client${generics.inst}, + pub(crate) conf: crate::Config, } #{client_docs:W} @@ -345,7 +348,8 @@ class FluentClientGenerator( """, "generics_decl" to generics.decl, "client" to clientDep.asType(), - "client_docs" to writable { + "client_docs" to writable + { customizations.forEach { it.section( FluentClientSection.FluentClientDocs( @@ -362,12 +366,16 @@ class FluentClientGenerator( ) { operations.forEach { operation -> val name = symbolProvider.toSymbol(operation).name + val fullPath = "crate::client::fluent_builders::$name" + val maybePaginated = if (operation.isPaginated(model)) { + "\n/// This operation supports pagination. See [`into_paginator()`]($fullPath::into_paginator)." + } else "" rust( """ /// Constructs a fluent builder for the `$name` operation. /// - /// See [`$name`](crate::client::fluent_builders::$name) for more information about the - /// operation and its arguments. + /// See [`$name`]($fullPath) for more information about the + /// operation and its arguments.$maybePaginated pub fn ${ clientOperationFnName( operation, @@ -445,8 +453,8 @@ class FluentClientGenerator( #{operation_err}, #{input}OperationRetryAlias>, { - let input = self.inner.build().map_err(|err|#{sdk_err}::ConstructionFailure(err.into()))?; - let op = input.make_operation(&self.handle.conf) + let op = self.inner.build().map_err(|err|#{sdk_err}::ConstructionFailure(err.into()))? + .make_operation(&self.handle.conf) .await .map_err(|err|#{sdk_err}::ConstructionFailure(err.into()))?; self.handle.client.call(op).await @@ -459,6 +467,19 @@ class FluentClientGenerator( .copy(name = "result::SdkError"), "client" to clientDep.asType(), ) + PaginatorGenerator.paginatorType(codegenContext, generics, operation)?.also { paginatorType -> + rustTemplate( + """ + /// Create a paginator for this request + /// + /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns a [`Stream`](tokio_stream::Stream). + pub fn into_paginator(self) -> #{Paginator}${generics.inst} { + #{Paginator}::new(self.handle, self.inner) + } + """, + "Paginator" to paginatorType + ) + } writeCustomizations( customizations, FluentClientSection.FluentBuilderImpl( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/NestedAccessorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/NestedAccessorGenerator.kt new file mode 100644 index 000000000..d6ad3d56c --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/NestedAccessorGenerator.kt @@ -0,0 +1,88 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.makeOptional +import software.amazon.smithy.rust.codegen.smithy.mapRustType +import software.amazon.smithy.rust.codegen.smithy.protocols.lensName + +/** Generator for accessing nested fields through optional values **/ +class NestedAccessorGenerator(private val symbolProvider: RustSymbolProvider) { + private val module = RustModule("lens", RustMetadata(public = false), "Generated accessors for nested fields") + /** + * Generate an accessor on [root] that consumes [root] and returns an `Option` for the nested item + */ + fun generateOwnedAccessor(root: StructureShape, path: List): RuntimeType { + check(path.isNotEmpty()) { "must not be called on an empty path" } + val baseType = symbolProvider.toSymbol(path.last()) + val fnName = symbolProvider.lensName("", root, path) + return RuntimeType.forInlineFun(fnName, module) { + it.rustTemplate( + """ + pub(crate) fn $fnName(input: #{Input}) -> #{Output} { + #{body:W} + } + """, + "Input" to symbolProvider.toSymbol(root), "Output" to baseType.makeOptional(), "body" to generateBody(path, false) + ) + } + } + + /** + * Generate an accessor on [root] that takes a reference and returns an `Option<&T>` for the nested item + */ + fun generateBorrowingAccessor(root: StructureShape, path: List): RuntimeType { + check(path.isNotEmpty()) { "must not be called on an empty path" } + val baseType = symbolProvider.toSymbol(path.last()).makeOptional() + val fnName = symbolProvider.lensName("ref", root, path) + val referencedType = baseType.mapRustType { (it as RustType.Option).referenced(lifetime = null) } + return RuntimeType.forInlineFun(fnName, module) { + it.rustTemplate( + """ + pub(crate) fn $fnName(input: &#{Input}) -> #{Output} { + #{body:W} + } + """, + "Input" to symbolProvider.toSymbol(root), "Output" to referencedType, "body" to generateBody(path, true) + ) + } + } + + private fun generateBody(path: List, reference: Boolean): Writable = writable { + val ref = if (reference) { "&" } else { "" } + if (path.isEmpty()) { + rust("Some(input)") + } else { + val head = path.first() + if (symbolProvider.toSymbol(head).isOptional()) { + rust( + """ + let input = match ${ref}input.${symbolProvider.toMemberName(head)} { + None => return None, + Some(t) => t + }; + """ + ) + } else { + rust("let input = input.${symbolProvider.toMemberName(head)};") + } + // Note: although _this_ function is recursive, it generates a series of `if let` statements with early returns. + generateBody(path.drop(1), reference)(this) + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGenerator.kt new file mode 100644 index 000000000..6531e3112 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGenerator.kt @@ -0,0 +1,289 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.PaginatedIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.traits.IdempotencyTokenTrait +import software.amazon.smithy.model.traits.PaginatedTrait +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.rustlang.Writable +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.render +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.rustlang.writable +import software.amazon.smithy.rust.codegen.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.smithy.rustType +import software.amazon.smithy.rust.codegen.util.PANIC +import software.amazon.smithy.rust.codegen.util.findMemberWithTrait +import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.orNull +import software.amazon.smithy.rust.codegen.util.outputShape +import software.amazon.smithy.rust.codegen.util.toPascalCase + +// TODO(https://github.com/awslabs/smithy-rs/issues/1013) Support pagination when the idempotency trait is present +fun OperationShape.isPaginated(model: Model) = + hasTrait() && inputShape(model) + .findMemberWithTrait(model) == null + +class PaginatorGenerator private constructor( + private val model: Model, + private val symbolProvider: RustSymbolProvider, + service: ServiceShape, + operation: OperationShape, + private val generics: ClientGenerics +) { + + companion object { + fun paginatorType( + codegenContext: CodegenContext, + generics: ClientGenerics, + operationShape: OperationShape + ): RuntimeType? { + return if (operationShape.isPaginated(codegenContext.model)) { + PaginatorGenerator( + codegenContext.model, + codegenContext.symbolProvider, + codegenContext.serviceShape, + operationShape, + generics + ).paginatorType() + } else { + null + } + } + } + + private val paginatorName = "${operation.id.name.toPascalCase()}Paginator" + private val runtimeConfig = symbolProvider.config().runtimeConfig + private val idx = PaginatedIndex.of(model) + private val paginationInfo = + idx.getPaginationInfo(service, operation).orNull() ?: PANIC("failed to load pagination info") + private val module = RustModule( + "paginator", + RustMetadata(public = true), + documentation = "Paginators for the service" + ) + + private val outputType = operation.outputShape(model) + + private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun( + paginatorName, + module, + generate() + ) + + private val codegenScope = arrayOf( + "generics" to generics.decl, + "bounds" to generics.bounds, + "page_size_setter" to pageSizeSetter(), + + // Operation Types + "operation" to symbolProvider.toSymbol(operation), + "Input" to symbolProvider.toSymbol(operation.inputShape(model)), + "Output" to symbolProvider.toSymbol(operation.outputShape(model)), + "Error" to operation.errorSymbol(symbolProvider), + "Builder" to operation.inputShape(model).builderSymbol(symbolProvider), + + // SDK Types + "SdkError" to CargoDependency.SmithyHttp(runtimeConfig).asType() + .copy(name = "result::SdkError"), + "client" to CargoDependency.SmithyClient(runtimeConfig).asType(), + "fn_stream" to CargoDependency.SmithyAsync(runtimeConfig).asType().member("future::fn_stream"), + + // External Types + "Stream" to CargoDependency.TokioStream.asType().member("Stream") + + ) + + /** Generate the paginator struct & impl **/ + private fun generate() = writable { + val outputTokenLens = NestedAccessorGenerator(symbolProvider).generateBorrowingAccessor( + outputType, + paginationInfo.outputTokenMemberPath + ) + val inputTokenMember = symbolProvider.toMemberName(paginationInfo.inputTokenMember) + rustTemplate( + """ + /// Paginator for #{operation:D} + pub struct $paginatorName#{generics:W} { + handle: std::sync::Arc, + builder: #{Builder} + } + + impl${generics.inst} ${paginatorName}${generics.inst} where #{bounds:W} { + /// Create a new paginator-wrapper + pub(crate) fn new(handle: std::sync::Arc, builder: #{Builder}) -> Self { + Self { + handle, + builder, + } + } + + #{page_size_setter:W} + + #{items_fn:W} + + + /// Create the pagination stream + /// + /// _Note:_ No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next)). + pub fn send(self) -> impl #{Stream}>> + Unpin + where + R::Policy: #{client}::bounds::SmithyRetryPolicy< + #{Input}OperationOutputAlias, + #{Output}, + #{Error}, + #{Input}OperationRetryAlias + >, { + // Move individual fields out of self for the borrow checker + let builder = self.builder; + let handle = self.handle; + #{fn_stream}::FnStream::new(move |tx| Box::pin(async move { + // Build the input for the first time. If required fields are missing, this is where we'll produce an early error. + let mut input = match builder.build().map_err(|err| #{SdkError}::ConstructionFailure(err.into())) { + Ok(input) => input, + Err(e) => { let _ = tx.send(Err(e)).await; return; } + }; + loop { + let op = match input.make_operation(&handle.conf) + .await + .map_err(|err| #{SdkError}::ConstructionFailure(err.into())) { + Ok(op) => op, + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + }; + let resp = handle.client.call(op).await; + // If the input member is None or it was an error + let done = match resp { + Ok(ref resp) => { + input.$inputTokenMember = #{output_token}(resp).cloned(); + input.$inputTokenMember.is_none() + }, + Err(_) => true, + }; + if tx.send(resp).await.is_err() { + // receiving end was dropped + return + } + if done { + return + } + } + + })) + } + } + """, + *codegenScope, + "items_fn" to itemsFn(), + "output_token" to outputTokenLens + ) + } + + /** Type of the inner item of the paginator */ + private fun itemType(): String { + val members = paginationInfo.itemsMemberPath + val type = symbolProvider.toSymbol(model.expectShape(members.last().target)).rustType() + check(type is RustType.Vec || type is RustType.HashMap) + return when (type) { + is RustType.Vec -> type.member.render(true) + is RustType.HashMap -> "(${type.key.render(true)}, ${type.member.render(true)})" + else -> PANIC("only HashMaps or Vecs may be used for item pagination.") + } + } + + /** Generate an `.items()` function to expose flattened pagination when modeled */ + private fun itemsFn(): Writable = writable { + itemsPaginator()?.also { itemPaginatorType -> + val documentedPath = + paginationInfo.itemsMemberPath.joinToString(".") { symbolProvider.toMemberName(it) } + rustTemplate( + """ + /// Create a flattened paginator + /// + /// This paginator automatically flattens results using `$documentedPath`. Queries to the underlying service + /// are dispatched lazily. + pub fn items(self) -> #{ItemPaginator}${generics.inst} { + #{ItemPaginator}(self) + } + """, + "ItemPaginator" to itemPaginatorType + ) + } + } + + /** Generate a struct with a `items()` method that flattens the paginator **/ + private fun itemsPaginator(): RuntimeType? = if (paginationInfo.itemsMemberPath.isEmpty()) { + null + } else { + RuntimeType.forInlineFun("${paginatorName}Items", module) { + it.rustTemplate( + """ + /// Flattened paginator for `$paginatorName` + /// + /// This is created with [`.items()`]($paginatorName::items) + pub struct ${paginatorName}Items#{generics:W}($paginatorName${generics.inst}); + + impl ${generics.inst} ${paginatorName}Items${generics.inst} where #{bounds:W} { + /// Create the pagination stream + /// + /// _Note: No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next))._ + /// + /// To read the entirety of the paginator, use [`.collect::, _>()`](tokio_stream::StreamExt::collect). + pub fn send(self) -> impl #{Stream}>> + Unpin + where + R::Policy: #{client}::bounds::SmithyRetryPolicy< + #{Input}OperationOutputAlias, + #{Output}, + #{Error}, + #{Input}OperationRetryAlias + >, { + #{fn_stream}::TryFlatMap::new(self.0.send()).flat_map(|page| #{extract_items}(page).unwrap_or_default().into_iter()) + } + } + + """, + "extract_items" to NestedAccessorGenerator(symbolProvider).generateOwnedAccessor( + outputType, + paginationInfo.itemsMemberPath + ), + *codegenScope + ) + } + } + + private fun pageSizeSetter() = writable { + paginationInfo.pageSizeMember.orNull()?.also { + val memberName = symbolProvider.toMemberName(it) + val pageSizeT = symbolProvider.toSymbol(it).rustType().stripOuter().render(true) + rust( + """ + /// Set the page size + /// + /// _Note: this method will override any previously set value for `$memberName`_ + pub fn page_size(mut self, limit: $pageSizeT) -> Self { + self.builder.$memberName = Some(limit); + self + } + """ + ) + } + } +} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt index 57e53fedc..84f3ba74d 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/ServiceGenerator.kt @@ -34,9 +34,8 @@ class ServiceGenerator( private val index = TopDownIndex.of(config.model) /** - * Render Service Specific code. Code will end up in different files via `useShapeWriter`. See `SymbolVisitor.kt` + * Render Service-specific code. Code will end up in different files via `useShapeWriter`. See `SymbolVisitor.kt` * which assigns a symbol location to each shape. - * */ fun render() { val operations = index.getContainedOperations(config.serviceShape).sortedBy { it.id } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/InlineFunctionNamer.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/InlineFunctionNamer.kt index d07681a41..0f43eb11c 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/InlineFunctionNamer.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/InlineFunctionNamer.kt @@ -18,6 +18,12 @@ import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.util.PANIC import software.amazon.smithy.rust.codegen.util.toSnakeCase +fun RustSymbolProvider.lensName(prefix: String, root: Shape, path: List): String { + val base = shapeFunctionName("${prefix}lens", root) + val rest = path.joinToString("_") { toMemberName(it) } + return "${base}_$rest" +} + /** * Creates a unique name for a serialization function. * diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGeneratorTest.kt new file mode 100644 index 000000000..baf0a9c1a --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/PaginatorGeneratorTest.kt @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.smithy.RustCodegenPlugin +import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.testutil.generatePluginContext +import software.amazon.smithy.rust.codegen.util.runCommand + +internal class PaginatorGeneratorTest { + private val model = """ + namespace test + use aws.protocols#awsJson1_1 + + @awsJson1_1 + service TestService { + operations: [PaginatedList, PaginatedMap] + } + + @readonly + @paginated(inputToken: "nextToken", outputToken: "inner.token", + pageSize: "maxResults", items: "inner.items") + operation PaginatedList { + input: GetFoosInput, + output: GetFoosOutput + } + + @readonly + @paginated(inputToken: "nextToken", outputToken: "inner.token", + pageSize: "maxResults", items: "inner.mapItems") + operation PaginatedMap { + input: GetFoosInput, + output: GetFoosOutput + } + + structure GetFoosInput { + maxResults: Integer, + nextToken: String + } + + structure Inner { + token: String, + + @required + items: StringList, + + @required + mapItems: StringMap + } + + structure GetFoosOutput { + inner: Inner + } + + list StringList { + member: String + } + + map StringMap { + key: String, + value: Integer + } + """.asSmithyModel() + + @Test + fun `generate correct paginators`() { + val (ctx, testDir) = generatePluginContext(model) + RustCodegenPlugin().execute(ctx) + "cargo test".runCommand(testDir) + } +} diff --git a/design/src/rfcs/rfc0008_paginators.md b/design/src/rfcs/rfc0008_paginators.md new file mode 100644 index 000000000..26a086661 --- /dev/null +++ b/design/src/rfcs/rfc0008_paginators.md @@ -0,0 +1,197 @@ +## Summary + +Smithy [models paginated responses](https://awslabs.github.io/smithy/1.0/spec/core/behavior-traits.html#paginated-trait) +. Customers of Smithy generated code & the Rust SDK will have an improved user experience if code is generated to +support this. Fundamentally, paginators are a way to automatically make a series of requests with the SDK, where subsequent +requests automatically forward output from the previous responses. There is nothing a paginator does that a user could not do manually, +they merely simplify the common task of interacting with paginated APIs. **Specifically, a paginator will resend the orginal request +but with `inputToken` updated to the value of the previous `outputToken`. + +In this RFC, we propose modeling paginated data as +a [`Stream`](https://docs.rs/tokio-stream/0.1.5/tokio_stream/#traits) of output shapes. + +- When an output is paginated, a `paginate()` method will be added to the high level builder +- An `Paginator` struct will be generated into the `paginator` module. +- If `items` is modeled, `paginate().items()` will be added to produce the paginated + items. `PaginatorItems` will be generated into the `paginator` module. + +The [`Stream`](https://docs.rs/tokio-stream/latest/tokio_stream/index.html) trait enables customers to use a number of +abstractions including simple looping, and `collect()`ing all data in a single call. A paginator will resend the +original input, but with the field marked `inputToken` to the value of `outputToken` in the previous output. + +Usage example: + +```rust +let paginator = client + .list_tables() + .paginate() + .items() + .page_size(10) + .send() + .await; +let tables: Result, _ > = paginator.collect().await; +``` + +Paginators are lazy and only retrieve pages when polled by a client. + +### Details + +Paginators will be generated into the `paginator` module of service crates. Currently, paginators are _not_ feature gated, but this +could be considered in the future. A `paginator` struct captures 2 pieces of data: + +```rust +// dynamodb/src/paginator.rs +struct ListTablesPaginator { + // holds the low-level client and configuration + handle: Arc>, + + // input builder to construct the actual input on demand + input: ListTablesInputBuilder +} +``` + +In addition to the basic usage example above, when `pageSize` is modeled, customers can specify the page size during +pagination: + +```rust +let mut tables = vec![]; +let mut pages = client + .list_tables() + .paginate() + .page_size(20) + .send(); +while let Some(next_page) = pages.try_next().await? { + // pages of 20 items requested from DynamoDb + tables.extend(next_page.table_names.unwrap_or_default().into_iter()); +} +``` + +Paginators define a public method `send()`. This method +returns `impl Stream`. This uses `FnStream` defined in the `aws-smithy-async` crate which +enables demand driven execution of a closure. A rendezvous channel is used which will block on `send` until demand exists. + +When modeled by Smithy, `page_size` which automatically sets the appropriate page_size parameter and `items()` which returns an +automatically flattened paginator are also generated. **Note**: `page_size` directly sets the modeled parameter on the internal builder. +This means that a value set for page size will override any previously set value for that field. +```rust +// Generated paginator for ListTables +impl ListTablesPaginator +{ + /// Set the page size + pub fn page_size(mut self, limit: i32) -> Self { + self.builder.limit = Some(limit); + self + } + + /// Create a flattened paginator + /// + /// This paginator automatically flattens results using `table_names`. Queries to the underlying service + /// are dispatched lazily. + pub fn items(self) -> crate::paginator::ListTablesPaginatorItems { + crate::paginator::ListTablesPaginatorItems(self) + } + + /// Create the pagination stream + /// + /// _Note:_ No requests will be dispatched until the stream is used (eg. with [`.next().await`](tokio_stream::StreamExt::next)). + pub async fn send( + self, + ) -> impl tokio_stream::Stream< + Item = std::result::Result< + crate::output::ListTablesOutput, + aws_smithy_http::result::SdkError, + >, + > + Unpin + { + // Move individual fields out of self for the borrow checker + let builder = self.builder; + let handle = self.handle; + fn_stream::FnStream::new(move |tx| { + Box::pin(async move { + // Build the input for the first time. If required fields are missing, this is where we'll produce an early error. + let mut input = match builder.build().map_err(|err| { + SdkError::ConstructionFailure(err.into()) + }) { + Ok(input) => input, + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + }; + loop { + let op = match input.make_operation(&handle.conf).await.map_err(|err| { + SdkError::ConstructionFailure(err.into()) + }) { + Ok(op) => op, + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + }; + let resp = handle.client.call(op).await; + // If the input member is None or it was an error + let done = match resp { + Ok(ref resp) => { + input.exclusive_start_table_name = crate::lens::reflens_structure_crate_output_list_tables_output_last_evaluated_table_name(resp).cloned(); + input.exclusive_start_table_name.is_none() + } + Err(_) => true, + }; + if let Err(_) = tx.send(resp).await { + // receiving end was dropped + return; + } + if done { + return; + } + } + }) + }) + } +} +``` + +**On Box::pin**: The stream returned by `AsyncStream` does not implement `Unpin`. Unfortunately, this makes iteration +require an invocation of `pin_mut!` and generates several hundred lines of compiler errors. Box::pin seems a worthwhile +trade off to improve the user experience. + +**On the `+ Unpin` bound**: Because auto-traits leak across `impl Trait` boundaries, `+ Unpin` prevents accidental +regressions in the generated code which would break users. + +**On the crate::reflens::...**: We use `LensGenerator.kt` to generate potentially complex accessors to deeply nested fields. + +### Updates to ergonomic clients + +The `builders` generated by ergonomic clients will gain the following method, if they represent an operation that implements the `Paginated` trait: + +```rust +/// Create a paginator for this request +/// +/// Paginators are used by calling [`send().await`](crate::paginator::ListTablesPaginator::send) which returns a [`Stream`](tokio_stream::Stream). +pub fn paginate(self) -> crate::paginator::ListTablesPaginator { + crate::paginator::ListTablesPaginator::new(self.handle, self.inner) +} +``` + +## Discussion Areas +### On `send().await` +Calling `send().await` is not necessary from an API perspective—we could have the paginators impl-stream directly. However, +it enables using `impl Trait` syntax and also makes the API consistent with other SDK APIs. + +### On `tokio_stream::Stream` +Currently, the core trait we use is `tokio_stream::Stream`. This is a re-export from futures-util. There are a few other choices: +1. Re-export `Stream` from tokio_stream. +2. Use `futures_util` directly + +### On Generics +Currently, the paginators forward the generics from the client (`C, M, R`) along with their fairly annoying bounds. +However, if we wanted to we _could_ simplify this and erase all the generics when the paginator was created. Since everything +is code generated, there isn't actually much duplicated code in the generator, just in the generated code. + +## Changes Checklist +- [x] Create and test `FnStream` abstraction +- [x] Generate page-level paginators +- [x] Generate `.items()` paginators +- [x] Generate doc hints pointing people to paginators +- [x] Integration test using mocked HTTP traffic against a generated paginator for a real service +- [ ] Integration test using real traffic diff --git a/rust-runtime/aws-smithy-async/Cargo.toml b/rust-runtime/aws-smithy-async/Cargo.toml index 18d412449..4fa06eb71 100644 --- a/rust-runtime/aws-smithy-async/Cargo.toml +++ b/rust-runtime/aws-smithy-async/Cargo.toml @@ -8,15 +8,16 @@ license = "Apache-2.0" repository = "https://github.com/awslabs/smithy-rs" [features] -rt-tokio = ["tokio"] +rt-tokio = ["tokio/time"] [dependencies] pin-project-lite = "0.2" -tokio = { version = "1.6", features = ["time"], optional = true } +tokio = { version = "1.6", features = ["sync"] } +tokio-stream = "0.1.8" +futures-util = "0.3.16" [dev-dependencies] -tokio = { version = "1.6", features = ["rt", "macros"] } -futures-util = "0.3.16" +tokio = { version = "1.6", features = ["rt", "macros", "test-util"] } [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-async/src/future/fn_stream.rs b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs new file mode 100644 index 000000000..352badbd6 --- /dev/null +++ b/rust-runtime/aws-smithy-async/src/future/fn_stream.rs @@ -0,0 +1,265 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +use crate::future::rendezvous; +use futures_util::StreamExt; +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_stream::{Iter, Once, Stream}; + +pin_project! { + /// Utility to drive a stream with an async function and a channel. + /// + /// The closure is passed a reference to a `Sender` which acts as a rendezvous channel. Messages + /// sent to the sender will be emitted to the stream. Because the stream is 1-bounded, the function + /// will not proceed until the stream is read. + /// + /// This utility is used by generated paginators to generate a stream of paginated results. + /// + /// If `tx.send` returns an error, the function MUST return immediately. + /// + /// # Examples + /// ```no_run + /// use tokio_stream::StreamExt; + /// # async fn docs() { + /// use aws_smithy_async::future::fn_stream::FnStream; + /// let stream = FnStream::new(|tx| Box::pin(async move { + /// if let Err(_) = tx.send("Hello!").await { + /// return; + /// } + /// if let Err(_) = tx.send("Goodbye!").await { + /// return; + /// } + /// })); + /// assert_eq!(stream.collect::>().await, vec!["Hello!", "Goodbye!"]); + /// # } + pub struct FnStream { + #[pin] + rx: rendezvous::Receiver, + #[pin] + generator: F, + } +} + +impl FnStream { + pub fn new(generator: T) -> Self + where + T: FnOnce(rendezvous::Sender) -> F, + { + let (tx, rx) = rendezvous::channel::(); + Self { + rx, + generator: generator(tx), + } + } +} + +impl Stream for FnStream +where + F: Future, +{ + type Item = Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut me = self.project(); + match me.rx.poll_recv(cx) { + Poll::Ready(item) => Poll::Ready(item), + Poll::Pending => { + let _ = me.generator.poll(cx); + Poll::Pending + } + } + } +} + +/// Utility wrapper to flatten paginated results +/// +/// When flattening paginated results, it's most convenient to produce an iterator where the `Result` +/// is present in each item. This provides `items()` which can wrap an stream of `Result` +/// and produce a stream of `Result`. +pub struct TryFlatMap(I); + +impl TryFlatMap { + /// Create a `TryFlatMap` that wraps the input + pub fn new(i: I) -> Self { + Self(i) + } + + pub fn flat_map(self, map: M) -> impl Stream> + where + I: Stream>, + M: Fn(Page) -> Iter, + Iter: IntoIterator, + { + self.0.flat_map(move |page| match page { + Ok(page) => OnceOrMany::Many { + many: tokio_stream::iter(map(page).into_iter().map(Ok)), + }, + Err(e) => OnceOrMany::Once { + once: tokio_stream::once(Err(e)), + }, + }) + } +} + +pin_project! { + /// Helper enum to to support returning `Once` and `Iter` from `Items::items` + #[project = OnceOrManyProj] + enum OnceOrMany { + Many { #[pin] many: Iter }, + Once { #[pin] once: Once }, + } +} + +impl Stream for OnceOrMany +where + Iter: Iterator, +{ + type Item = Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + match me { + OnceOrManyProj::Many { many } => many.poll_next(cx), + OnceOrManyProj::Once { once } => once.poll_next(cx), + } + } +} + +#[cfg(test)] +mod test { + use crate::future::fn_stream::{FnStream, TryFlatMap}; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use tokio_stream::StreamExt; + + /// basic test of FnStream functionality + #[tokio::test] + async fn fn_stream_returns_results() { + tokio::time::pause(); + let mut stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send("1").await.expect("failed to send"); + tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(1)).await; + tx.send("2").await.expect("failed to send"); + tokio::time::sleep(Duration::from_secs(1)).await; + tx.send("3").await.expect("failed to send"); + }) + }); + let mut out = vec![]; + while let Some(value) = stream.next().await { + out.push(value); + } + assert_eq!(out, vec!["1", "2", "3"]); + } + + /// Tests that the generator will not advance until demand exists + #[tokio::test] + async fn waits_for_reader() { + let progress = Arc::new(Mutex::new(0)); + let mut stream = FnStream::new(|tx| { + let progress = progress.clone(); + Box::pin(async move { + *progress.lock().unwrap() = 1; + tx.send("1").await.expect("failed to send"); + *progress.lock().unwrap() = 2; + tx.send("2").await.expect("failed to send"); + *progress.lock().unwrap() = 3; + tx.send("3").await.expect("failed to send"); + *progress.lock().unwrap() = 4; + }) + }); + assert_eq!(*progress.lock().unwrap(), 0); + stream.next().await.expect("ready"); + assert_eq!(*progress.lock().unwrap(), 1); + + assert_eq!(stream.next().await.expect("ready"), "2"); + assert_eq!(*progress.lock().unwrap(), 2); + + let _ = stream.next().await.expect("ready"); + assert_eq!(*progress.lock().unwrap(), 3); + assert_eq!(stream.next().await, None); + assert_eq!(*progress.lock().unwrap(), 4); + } + + #[tokio::test] + async fn generator_with_errors() { + let mut stream = FnStream::new(|tx| { + Box::pin(async move { + for i in 0..5 { + if i != 2 { + if let Err(_) = tx.send(Ok(i)).await { + return; + } + } else { + tx.send(Err(i)).await.unwrap(); + return; + } + } + }) + }); + let mut out = vec![]; + while let Some(Ok(value)) = stream.next().await { + out.push(value); + } + assert_eq!(out, vec![0, 1]); + } + + #[tokio::test] + async fn flatten_items_ok() { + #[derive(Debug)] + struct Output { + items: Vec, + } + let stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(Output { + items: vec![1, 2, 3], + })) + .await + .unwrap(); + tx.send(Ok(Output { + items: vec![4, 5, 6], + })) + .await + .unwrap(); + }) + }); + assert_eq!( + TryFlatMap(stream) + .flat_map(|output| output.items.into_iter()) + .collect::, &str>>() + .await, + Ok(vec![1, 2, 3, 4, 5, 6]) + ) + } + + #[tokio::test] + async fn flatten_items_error() { + #[derive(Debug)] + struct Output { + items: Vec, + } + let stream = FnStream::new(|tx| { + Box::pin(async move { + tx.send(Ok(Output { + items: vec![1, 2, 3], + })) + .await + .unwrap(); + tx.send(Err("bummer")).await.unwrap(); + }) + }); + assert_eq!( + TryFlatMap(stream) + .flat_map(|output| output.items.into_iter()) + .collect::, &str>>() + .await, + Err("bummer") + ) + } +} diff --git a/rust-runtime/aws-smithy-async/src/future/mod.rs b/rust-runtime/aws-smithy-async/src/future/mod.rs index f88aef706..a39925e14 100644 --- a/rust-runtime/aws-smithy-async/src/future/mod.rs +++ b/rust-runtime/aws-smithy-async/src/future/mod.rs @@ -5,6 +5,8 @@ //! Useful runtime-agnostic future implementations. +pub mod fn_stream; pub mod never; pub mod now_or_later; +pub mod rendezvous; pub mod timeout; diff --git a/rust-runtime/aws-smithy-async/src/future/rendezvous.rs b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs new file mode 100644 index 000000000..11981e520 --- /dev/null +++ b/rust-runtime/aws-smithy-async/src/future/rendezvous.rs @@ -0,0 +1,132 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +//! Rendezvous channel implementation +//! +//! Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send +//! until there is an active receiver waiting. This implementation uses a Semaphore to record demand +//! and coordinate with the receiver. +//! +//! Rendezvous channels should be used with care—it's inherently easy to deadlock unless they're being +//! used from separate tasks or an a coroutine setup (e.g. [`crate::future::fn_stream::FnStream`]) + +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::Semaphore; + +/// Create a new rendezvous channel +/// +/// Rendezvous channels are equivalent to a channel with a 0-sized buffer: A sender cannot send +/// until this is an active receiver waiting. This implementation uses a semaphore to record demand +/// and coordinate with the receiver. +pub fn channel() -> (Sender, Receiver) { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let semaphore = Arc::new(Semaphore::new(0)); + ( + Sender { + semaphore: semaphore.clone(), + chan: tx, + }, + Receiver { + semaphore, + chan: rx, + needs_permit: false, + }, + ) +} + +/// Sender-half of a channel +pub struct Sender { + semaphore: Arc, + chan: tokio::sync::mpsc::Sender, +} + +impl Sender { + /// Send `item` into the channel waiting until there is matching demand + /// + /// Unlike something like `tokio::sync::mpsc::Channel` where sending a value will be buffered until + /// demand exists, a rendezvous sender will wait until matching demand exists before this function will return. + pub async fn send(&self, item: T) -> Result<(), SendError> { + let result = self.chan.send(item).await; + // If this is an error, the rx half has been dropped. We will never get demand. + if result.is_ok() { + // The key here is that we block _after_ the send until more demand exists + self.semaphore + .acquire() + .await + .expect("semaphore is never closed") + .forget(); + } + result + } +} + +/// Receiver half of the rendezvous channel +pub struct Receiver { + semaphore: Arc, + chan: tokio::sync::mpsc::Receiver, + needs_permit: bool, +} + +impl Receiver { + /// Polls to receive an item from the channel + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + // This uses `needs_permit` to track whether this is the first poll since we last returned an item. + // If it is, we will grant a permit to the semaphore. Otherwise, we'll just forward the response through. + let resp = self.chan.poll_recv(cx); + // If there is no data on the channel, but we are reading, then give a permit so we can load data + if self.needs_permit && matches!(resp, Poll::Pending) { + self.needs_permit = false; + self.semaphore.add_permits(1); + } + + if matches!(resp, Poll::Ready(_)) { + // we returned an item, no need to provide another permit until we fail to read from the channel again + self.needs_permit = true; + } + resp + } +} + +#[cfg(test)] +mod test { + use crate::future::rendezvous::{channel, Receiver}; + use std::sync::{Arc, Mutex}; + use tokio::macros::support::poll_fn; + + async fn recv(rx: &mut Receiver) -> Option { + poll_fn(|cx| rx.poll_recv(cx)).await + } + + #[tokio::test] + async fn send_blocks_caller() { + let (tx, mut rx) = channel::(); + let done = Arc::new(Mutex::new(0)); + let idone = done.clone(); + let send = tokio::spawn(async move { + *idone.lock().unwrap() = 1; + tx.send(0).await.unwrap(); + *idone.lock().unwrap() = 2; + tx.send(1).await.unwrap(); + *idone.lock().unwrap() = 3; + }); + assert_eq!(*done.lock().unwrap(), 0); + assert_eq!(recv(&mut rx).await, Some(0)); + assert_eq!(*done.lock().unwrap(), 1); + assert_eq!(recv(&mut rx).await, Some(1)); + assert_eq!(*done.lock().unwrap(), 2); + assert_eq!(recv(&mut rx).await, None); + assert_eq!(*done.lock().unwrap(), 3); + let _ = send.await; + } + + #[tokio::test] + async fn send_errors_when_rx_dropped() { + let (tx, rx) = channel::(); + drop(rx); + tx.send(0).await.expect_err("rx half dropped"); + } +} diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index 192beba34..ffa12d020 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -18,6 +18,7 @@ assert-json-diff = "1" pretty_assertions = "1.0" roxmltree = "0.14.1" + [package.metadata.docs.rs] all-features = true targets = ["x86_64-unknown-linux-gnu"] diff --git a/tools/publisher/src/subcommand/fix_manifests.rs b/tools/publisher/src/subcommand/fix_manifests.rs index 75e8a4d8f..c45f98917 100644 --- a/tools/publisher/src/subcommand/fix_manifests.rs +++ b/tools/publisher/src/subcommand/fix_manifests.rs @@ -53,18 +53,19 @@ async fn read_manifests(fs: Fs, manifest_paths: Vec) -> Result Result> { let mut versions = BTreeMap::new(); for manifest in manifests { - let name = manifest - .metadata - .get("package") - .and_then(|package| package.get("name")) + // ignore workspace manifests + let package = match manifest.metadata.get("package") { + Some(package) => package, + None => continue, + }; + let name = package + .get("name") .and_then(|name| name.as_str()) .ok_or_else(|| { anyhow::Error::msg(format!("{:?} is missing a package name", manifest.path)) })?; - let version = manifest - .metadata - .get("package") - .and_then(|package| package.get("version")) + let version = package + .get("version") .and_then(|name| name.as_str()) .ok_or_else(|| { anyhow::Error::msg(format!("{:?} is missing a package version", manifest.path)) -- GitLab