diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 16b3823fca9c85386d5e86658ab4f8d32be931e9..8bdfe0cdc13f25ee47ab2c9b932aef00181c3de7 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.Approx import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.AsyncStd import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.AsyncStream import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.BytesUtils @@ -26,6 +27,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Compani import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.TracingAppender import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.TracingSubscriber import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.TracingTest +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.smithyRuntime +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency.Companion.smithyRuntimeApi import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -73,6 +76,7 @@ class 
IntegrationTestDependencies( private val hasTests: Boolean, private val hasBenches: Boolean, ) : LibRsCustomization() { + private val runtimeConfig = codegenContext.runtimeConfig override fun section(section: LibRsSection) = when (section) { is LibRsSection.Body -> testDependenciesOnly { if (hasTests) { @@ -80,14 +84,22 @@ class IntegrationTestDependencies( .copy(features = setOf("test-util"), scope = DependencyScope.Dev) val smithyAsync = CargoDependency.smithyAsync(codegenContext.runtimeConfig) .copy(features = setOf("test-util"), scope = DependencyScope.Dev) + val smithyTypes = CargoDependency.smithyTypes(codegenContext.runtimeConfig) + .copy(features = setOf("test-util"), scope = DependencyScope.Dev) addDependency(smithyClient) addDependency(smithyAsync) + addDependency(smithyTypes) addDependency(CargoDependency.smithyProtocolTestHelpers(codegenContext.runtimeConfig)) addDependency(SerdeJson) addDependency(Tokio) addDependency(FuturesUtil) addDependency(Tracing.toDevDependency()) addDependency(TracingSubscriber) + + if (codegenContext.smithyRuntimeMode.generateOrchestrator) { + addDependency(smithyRuntime(runtimeConfig).copy(features = setOf("test-util"), scope = DependencyScope.Dev)) + addDependency(smithyRuntimeApi(runtimeConfig).copy(features = setOf("test-util"), scope = DependencyScope.Dev)) + } } if (hasBenches) { addDependency(Criterion) @@ -103,6 +115,7 @@ class IntegrationTestDependencies( private fun serviceSpecificCustomizations(): List = when (moduleName) { "transcribestreaming" -> listOf(TranscribeTestDependencies()) "s3" -> listOf(S3TestDependencies(codegenContext)) + "dynamodb" -> listOf(DynamoDbTestDependencies()) else -> emptyList() } } @@ -116,6 +129,13 @@ class TranscribeTestDependencies : LibRsCustomization() { } } +class DynamoDbTestDependencies : LibRsCustomization() { + override fun section(section: LibRsSection): Writable = + writable { + addDependency(Approx) + } +} + class S3TestDependencies(private val codegenContext: 
ClientCodegenContext) : LibRsCustomization() { override fun section(section: LibRsSection): Writable = writable { diff --git a/aws/sdk/integration-tests/dynamodb/Cargo.toml b/aws/sdk/integration-tests/dynamodb/Cargo.toml index 78011372866ddffa275853bbd7492a563a42e687..3cd3a98b49217803e9398e8ebf067cf712babc34 100644 --- a/aws/sdk/integration-tests/dynamodb/Cargo.toml +++ b/aws/sdk/integration-tests/dynamodb/Cargo.toml @@ -11,14 +11,18 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +approx = "0.5.1" +aws-config = { path = "../../build/aws-sdk/sdk/aws-config" } aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", features = ["test-util"] } aws-http = { path = "../../build/aws-sdk/sdk/aws-http" } aws-sdk-dynamodb = { path = "../../build/aws-sdk/sdk/dynamodb" } +aws-smithy-async = { path = "../../build/aws-sdk/sdk/aws-smithy-async" } aws-smithy-client = { path = "../../build/aws-sdk/sdk/aws-smithy-client", features = ["test-util", "rustls"] } aws-smithy-http = { path = "../../build/aws-sdk/sdk/aws-smithy-http" } -aws-smithy-types = { path = "../../build/aws-sdk/sdk/aws-smithy-types" } aws-smithy-protocol-test = { path = "../../build/aws-sdk/sdk/aws-smithy-protocol-test" } -aws-smithy-async = { path = "../../build/aws-sdk/sdk/aws-smithy-async" } +aws-smithy-runtime = { path = "../../build/aws-sdk/sdk/aws-smithy-runtime", features = ["test-util"]} +aws-smithy-runtime-api = { path = "../../build/aws-sdk/sdk/aws-smithy-runtime-api", features = ["test-util"]} +aws-smithy-types = { path = "../../build/aws-sdk/sdk/aws-smithy-types", features = ["test-util"]} aws-types = { path = "../../build/aws-sdk/sdk/aws-types" } bytes = "1.0.0" criterion = { version = "0.4.0" } @@ -26,8 +30,8 @@ futures-util = { version = "0.3.16", default-features = false } http = "0.2.0" serde_json = "1.0.0" tokio = { version = "1.23.1", features = ["full", "test-util"] } -tracing-subscriber = { 
version = "0.3.15", features = ["env-filter"] } tokio-stream = "0.1.5" +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } [[bench]] name = "deserialization_bench" diff --git a/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs b/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs new file mode 100644 index 0000000000000000000000000000000000000000..d3f05db8656a9254cd8b619d16ac2da2bf4864c7 --- /dev/null +++ b/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs @@ -0,0 +1,173 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#[cfg(aws_sdk_orchestrator_mode)] +mod test { + use aws_sdk_dynamodb::config::{Credentials, Region, SharedAsyncSleep}; + use aws_sdk_dynamodb::{config::retry::RetryConfig, error::ProvideErrorMetadata}; + use aws_smithy_async::rt::sleep::TokioSleep; + use aws_smithy_async::test_util::instant_time_and_sleep; + use aws_smithy_async::time::SharedTimeSource; + use aws_smithy_async::time::SystemTimeSource; + use aws_smithy_client::test_connection::TestConnection; + use aws_smithy_http::body::SdkBody; + use aws_smithy_runtime::client::retries::RetryPartition; + use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse}; + use aws_smithy_types::timeout::TimeoutConfigBuilder; + use std::time::{Duration, Instant, SystemTime}; + + fn req() -> HttpRequest { + http::Request::builder() + .body(SdkBody::from("request body")) + .unwrap() + } + + fn ok() -> HttpResponse { + http::Response::builder() + .status(200) + .header("server", "Server") + .header("content-type", "application/x-amz-json-1.0") + .header("content-length", "23") + .header("connection", "keep-alive") + .header("x-amz-crc32", "2335643545") + .body(SdkBody::from("{ \"TableNames\": [ \"Test\" ] }")) + .unwrap() + } + + fn err() -> HttpResponse { + http::Response::builder() + .status(500) + 
.body(SdkBody::from("{ \"message\": \"The request has failed because of an unknown error, exception or failure.\", \"code\": \"InternalServerError\" }")) + .unwrap() + } + + fn throttling_err() -> HttpResponse { + http::Response::builder() + .status(400) + .body(SdkBody::from("{ \"message\": \"The request was denied due to request throttling.\", \"code\": \"ThrottlingException\" }")) + .unwrap() + } + + #[tokio::test] + async fn test_adaptive_retries_with_no_throttling_errors() { + let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH); + + let events = vec![ + // First operation + (req(), err()), + (req(), err()), + (req(), ok()), + // Second operation + (req(), err()), + (req(), ok()), + // Third operation will fail, only errors + (req(), err()), + (req(), err()), + (req(), err()), + (req(), err()), + ]; + + let conn = TestConnection::new(events); + let config = aws_sdk_dynamodb::Config::builder() + .credentials_provider(Credentials::for_tests()) + .region(Region::new("us-east-1")) + .retry_config( + RetryConfig::adaptive() + .with_max_attempts(4) + .with_use_static_exponential_base(true), + ) + .time_source(SharedTimeSource::new(time_source)) + .sleep_impl(SharedAsyncSleep::new(sleep_impl.clone())) + .retry_partition(RetryPartition::new( + "test_adaptive_retries_with_no_throttling_errors", + )) + .http_connector(conn.clone()) + .build(); + let expected_table_names = vec!["Test".to_owned()]; + + // We create a new client each time to ensure that the cross-client retry state is working. 
+ let client = aws_sdk_dynamodb::Client::from_conf(config.clone()); + let res = client.list_tables().send().await.unwrap(); + assert_eq!(sleep_impl.total_duration(), Duration::from_secs(3)); + assert_eq!(res.table_names(), Some(expected_table_names.as_slice())); + // Three requests should have been made, two failing & one success + assert_eq!(conn.requests().len(), 3); + + let client = aws_sdk_dynamodb::Client::from_conf(config.clone()); + let res = client.list_tables().send().await.unwrap(); + assert_eq!(sleep_impl.total_duration(), Duration::from_secs(3 + 1)); + assert_eq!(res.table_names(), Some(expected_table_names.as_slice())); + // Two requests should have been made, one failing & one success (plus previous requests) + assert_eq!(conn.requests().len(), 5); + + let client = aws_sdk_dynamodb::Client::from_conf(config); + let err = client.list_tables().send().await.unwrap_err(); + assert_eq!(sleep_impl.total_duration(), Duration::from_secs(3 + 1 + 7),); + assert_eq!(err.code(), Some("InternalServerError")); + // four requests should have been made, all failing (plus previous requests) + assert_eq!(conn.requests().len(), 9); + } + + #[tokio::test] + async fn test_adaptive_retries_with_throttling_errors_times_out() { + tracing_subscriber::fmt::init(); + let events = vec![ + // First operation + (req(), err()), + (req(), ok()), + // Second operation + (req(), err()), + (req(), throttling_err()), + (req(), ok()), + ]; + + let conn = TestConnection::new(events); + let config = aws_sdk_dynamodb::Config::builder() + .credentials_provider(Credentials::for_tests()) + .region(Region::new("us-east-1")) + .retry_config( + RetryConfig::adaptive() + .with_max_attempts(4) + .with_initial_backoff(Duration::from_millis(50)) + .with_use_static_exponential_base(true), + ) + .timeout_config( + TimeoutConfigBuilder::new() + .operation_attempt_timeout(Duration::from_millis(100)) + .build(), + ) + .time_source(SharedTimeSource::new(SystemTimeSource::new())) + 
.sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) + .http_connector(conn.clone()) + .retry_partition(RetryPartition::new( + "test_adaptive_retries_with_throttling_errors_times_out", + )) + .build(); + + let expected_table_names = vec!["Test".to_owned()]; + let start = Instant::now(); + + // We create a new client each time to ensure that the cross-client retry state is working. + let client = aws_sdk_dynamodb::Client::from_conf(config.clone()); + let res = client.list_tables().send().await.unwrap(); + assert_eq!(res.table_names(), Some(expected_table_names.as_slice())); + // Two requests should have been made, one failing & one success + assert_eq!(conn.requests().len(), 2); + + let client = aws_sdk_dynamodb::Client::from_conf(config); + let err = client.list_tables().send().await.unwrap_err(); + assert_eq!(err.to_string(), "request has timed out".to_owned()); + // two requests should have been made, both failing (plus previous requests) + assert_eq!(conn.requests().len(), 2 + 2); + + let since = start.elapsed(); + // At least 300 milliseconds must pass: + // - 50ms for the first retry on attempt 1 + // - 50ms for the second retry on attempt 3 + // - 100ms for the throttling delay triggered by attempt 4, which required a delay longer than the attempt timeout. + // - 100ms for the 5th attempt, which would have succeeded, but required a delay longer than the attempt timeout.
+ assert!(since.as_secs_f64() > 0.3); + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt index 04be1fca913a4d7c00c5810d769eefaa04ac4d03..07ece3fb7840c1428e8e350235f48ce8400e1251 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt @@ -7,8 +7,12 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig @@ -16,12 +20,14 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : ConfigCustomization() { +class 
ResiliencyConfigCustomization(private val codegenContext: ClientCodegenContext) : ConfigCustomization() { private val runtimeConfig = codegenContext.runtimeConfig private val runtimeMode = codegenContext.smithyRuntimeMode private val retryConfig = RuntimeType.smithyTypes(runtimeConfig).resolve("retry") private val sleepModule = RuntimeType.smithyAsync(runtimeConfig).resolve("rt::sleep") private val timeoutModule = RuntimeType.smithyTypes(runtimeConfig).resolve("timeout") + private val smithyRuntimeCrate = RuntimeType.smithyRuntime(runtimeConfig) + private val retries = smithyRuntimeCrate.resolve("client::retries") private val moduleUseName = codegenContext.moduleUseName() private val codegenScope = arrayOf( *preludeScope, @@ -29,8 +35,17 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf "RetryConfig" to retryConfig.resolve("RetryConfig"), "SharedAsyncSleep" to sleepModule.resolve("SharedAsyncSleep"), "Sleep" to sleepModule.resolve("Sleep"), - "StandardRetryStrategy" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries::strategy::StandardRetryStrategy"), + "StandardRetryStrategy" to retries.resolve("strategy::StandardRetryStrategy"), + "SystemTime" to RuntimeType.std.resolve("time::SystemTime"), "TimeoutConfig" to timeoutModule.resolve("TimeoutConfig"), + "RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"), + "TokenBucket" to retries.resolve("TokenBucket"), + "ClientRateLimiter" to retries.resolve("ClientRateLimiter"), + "SharedTimeSource" to RuntimeType.smithyAsync(runtimeConfig).resolve("time::SharedTimeSource"), + "ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"), + "TokenBucketPartition" to retries.resolve("TokenBucketPartition"), + "RetryPartition" to retries.resolve("RetryPartition"), + "debug" to RuntimeType.Tracing.resolve("debug"), ) override fun section(section: ServiceConfig) = @@ -67,6 +82,15 @@ class ResiliencyConfigCustomization(codegenContext: 
ClientCodegenContext) : Conf pub fn timeout_config(&self) -> #{Option}<&#{TimeoutConfig}> { self.inner.load::<#{TimeoutConfig}>() } + + ##[doc(hidden)] + /// Returns a reference to the retry partition contained in this config, if any. + /// + /// WARNING: This method is unstable and may be removed at any time. Do not rely on this + /// method for anything! + pub fn retry_partition(&self) -> #{Option}<&#{RetryPartition}> { + self.inner.load::<#{RetryPartition}>() + } """, *codegenScope, ) @@ -311,13 +335,67 @@ class ResiliencyConfigCustomization(codegenContext: ClientCodegenContext) : Conf *codegenScope, ) } + + if (runtimeMode.defaultToOrchestrator) { + Attribute.DocHidden.render(this) + rustTemplate( + """ + /// Set the partition for retry-related state. When clients share a retry partition, they will + /// also share things like token buckets and client rate limiters. By default, all clients + /// for the same service will share a partition. + pub fn retry_partition(mut self, retry_partition: #{RetryPartition}) -> Self { + self.set_retry_partition(Some(retry_partition)); + self + } + """, + *codegenScope, + ) + + Attribute.DocHidden.render(this) + rustTemplate( + """ + /// Set the partition for retry-related state. When clients share a retry partition, they will + /// also share things like token buckets and client rate limiters. By default, all clients + /// for the same service will share a partition. 
+ pub fn set_retry_partition(&mut self, retry_partition: #{Option}<#{RetryPartition}>) -> &mut Self { + retry_partition.map(|r| self.inner.store_put(r)); + self + } + """, + *codegenScope, + ) + } } ServiceConfig.BuilderBuild -> { if (runtimeMode.defaultToOrchestrator) { rustTemplate( """ + let retry_partition = layer.load::<#{RetryPartition}>().cloned().unwrap_or_else(|| #{RetryPartition}::new("${codegenContext.serviceShape.id.name}")); let retry_config = layer.load::<#{RetryConfig}>().cloned().unwrap_or_else(#{RetryConfig}::disabled); + if retry_config.has_retry() { + #{debug}!("creating retry strategy with partition '{}'", retry_partition); + } + + if retry_config.mode() == #{RetryMode}::Adaptive { + if let Some(time_source) = layer.load::<#{SharedTimeSource}>().cloned() { + let seconds_since_unix_epoch = time_source + .now() + .duration_since(#{SystemTime}::UNIX_EPOCH) + .expect("the present takes place after the UNIX_EPOCH") + .as_secs_f64(); + let client_rate_limiter_partition = #{ClientRateLimiterPartition}::new(retry_partition.clone()); + let client_rate_limiter = CLIENT_RATE_LIMITER.get_or_init(client_rate_limiter_partition, || { + #{ClientRateLimiter}::new(seconds_since_unix_epoch) + }); + layer.store_put(client_rate_limiter); + } + } + + // The token bucket is used for both standard AND adaptive retries. 
+ let token_bucket_partition = #{TokenBucketPartition}::new(retry_partition); + let token_bucket = TOKEN_BUCKET.get_or_init(token_bucket_partition, #{TokenBucket}::default); + layer.store_put(token_bucket); layer.set_retry_strategy(#{DynRetryStrategy}::new(#{StandardRetryStrategy}::new(&retry_config))); """, *codegenScope, @@ -355,6 +433,10 @@ class ResiliencyReExportCustomization(private val runtimeConfig: RuntimeConfig) "pub use #{types_retry}::{RetryConfig, RetryConfigBuilder, RetryMode};", "types_retry" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry"), ) + rustTemplate( + "pub use #{RetryPartition};", + "RetryPartition" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::retries::RetryPartition"), + ) } rustCrate.withModule(ClientRustModule.Config.timeout) { rustTemplate( @@ -364,3 +446,33 @@ class ResiliencyReExportCustomization(private val runtimeConfig: RuntimeConfig) } } } + +class ResiliencyServiceRuntimePluginCustomization(codegenContext: ClientCodegenContext) : ServiceRuntimePluginCustomization() { + private val runtimeConfig = codegenContext.runtimeConfig + private val smithyRuntimeCrate = RuntimeType.smithyRuntime(runtimeConfig) + private val retries = smithyRuntimeCrate.resolve("client::retries") + private val codegenScope = arrayOf( + "TokenBucket" to retries.resolve("TokenBucket"), + "TokenBucketPartition" to retries.resolve("TokenBucketPartition"), + "ClientRateLimiter" to retries.resolve("ClientRateLimiter"), + "ClientRateLimiterPartition" to retries.resolve("ClientRateLimiterPartition"), + "StaticPartitionMap" to smithyRuntimeCrate.resolve("static_partition_map::StaticPartitionMap"), + ) + + override fun section(section: ServiceRuntimePluginSection): Writable = writable { + when (section) { + is ServiceRuntimePluginSection.DeclareSingletons -> { + // TODO(enableNewSmithyRuntimeCleanup) We can use the standard library's `OnceCell` once we upgrade the + // MSRV to 1.70 + rustTemplate( + """ + static TOKEN_BUCKET: 
#{StaticPartitionMap}<#{TokenBucketPartition}, #{TokenBucket}> = #{StaticPartitionMap}::new(); + static CLIENT_RATE_LIMITER: #{StaticPartitionMap}<#{ClientRateLimiterPartition}, #{ClientRateLimiter}> = #{StaticPartitionMap}::new(); + """, + *codegenScope, + ) + } + else -> emptySection + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt index fd1aacc68980d88399b0f6d7196c1ad80617f769..4b85a6bb533e6c252016bd741687abdf09b80dab 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt @@ -16,9 +16,11 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.Identity import software.amazon.smithy.rust.codegen.client.smithy.customizations.InterceptorConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization +import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.TimeSourceCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.TimeSourceOperationCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.rustlang.Feature import 
software.amazon.smithy.rust.codegen.core.smithy.RustCrate @@ -95,4 +97,13 @@ class RequiredCustomizations : ClientCodegenDecorator { } } } + + override fun serviceRuntimePluginCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = if (codegenContext.smithyRuntimeMode.generateOrchestrator) { + baseCustomizations + ResiliencyServiceRuntimePluginCustomization(codegenContext) + } else { + baseCustomizations + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt index c8f0457b6510f472387c248c8d2d3228101b90fe..bf47e9f41c5045ec5810f9606139e282ef7a7cd2 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceRuntimePluginGenerator.kt @@ -21,6 +21,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizat import software.amazon.smithy.rust.codegen.core.util.dq sealed class ServiceRuntimePluginSection(name: String) : Section(name) { + /** + * Hook for declaring singletons that store cross-operation state. + * + * Examples include token buckets, ID generators, etc. + */ + class DeclareSingletons : ServiceRuntimePluginSection("DeclareSingletons") + /** * Hook for adding additional things to config inside service runtime plugins. 
*/ @@ -128,6 +135,9 @@ class ServiceRuntimePluginGenerator( #{additional_interceptors} } } + + /// Cross-operation shared-state singletons + #{declare_singletons} """, *codegenScope, "config" to writable { @@ -154,6 +164,9 @@ class ServiceRuntimePluginGenerator( "additional_interceptors" to writable { writeCustomizations(customizations, ServiceRuntimePluginSection.RegisterInterceptor("_interceptors")) }, + "declare_singletons" to writable { + writeCustomizations(customizations, ServiceRuntimePluginSection.DeclareSingletons()) + }, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt index c4af5f97fb4121bea21ca34a5ccb752b8b3002f4..fdb92dd3f70af3cca7ccca83dbde5436ee0a918d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt @@ -67,6 +67,7 @@ open class MakeOperationGenerator( "OpBuildError" to runtimeConfig.operationBuildError(), "SdkBody" to RuntimeType.sdkBody(runtimeConfig), "SharedPropertyBag" to RuntimeType.smithyHttp(runtimeConfig).resolve("property_bag::SharedPropertyBag"), + "RetryMode" to RuntimeType.smithyTypes(runtimeConfig).resolve("retry::RetryMode"), ) fun generateMakeOperation( @@ -98,6 +99,12 @@ open class MakeOperationGenerator( "$fnType $functionName($self, _config: &#{config}::Config) -> $returnType", *codegenScope, ) { + rustTemplate( + """ + assert_ne!(_config.retry_config().map(|rc| rc.mode()), #{Option}::Some(#{RetryMode}::Adaptive), "Adaptive retry mode is unsupported, please use Standard mode or disable retries."); + """, + *codegenScope, + ) writeCustomizations(customizations, 
OperationSection.MutateInput(customizations, "self", "_config")) withBlock("let mut request = {", "};") { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 9187f35e6eb2a10fd07317cd65095ce68b77f9b3..d850b2cc26443cc0f9e4ab0a5d2b1695f6727d76 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -233,6 +233,7 @@ data class CargoDependency( val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) // Test-only dependencies + val Approx: CargoDependency = CargoDependency("approx", CratesIo("0.5.1"), DependencyScope.Dev) val AsyncStd: CargoDependency = CargoDependency("async-std", CratesIo("1.12.0"), DependencyScope.Dev) val AsyncStream: CargoDependency = CargoDependency("async-stream", CratesIo("0.3.0"), DependencyScope.Dev) val Criterion: CargoDependency = CargoDependency("criterion", CratesIo("0.4.0"), DependencyScope.Dev) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt index 669279c8d5332c7ce67771e76b2ac230d743e556..c7e6023f3cc95a0ed87d0e5e90f13fc10b2587bb 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt @@ -41,7 +41,7 @@ fun executePythonServerCodegenVisitor(pluginCtx: PluginContext) { fun cargoTest(workdir: Path) = // `--no-default-features` is required to disable 
`pyo3/extension-module` which causes linking errors // see `PyO3ExtensionModuleDecorator`'s comments fore more detail. - "cargo test --no-default-features".runCommand( + "cargo test --no-default-features --no-fail-fast".runCommand( workdir, mapOf( // Those are required to run tests on macOS, see: https://pyo3.rs/main/building_and_distribution#macos diff --git a/rust-runtime/aws-smithy-async/src/test_util.rs b/rust-runtime/aws-smithy-async/src/test_util.rs index 8c867869e83332bf1a3f084e2c51982220939513..fa1dfe300b5a9369807c1c701f6616d9d9da7b30 100644 --- a/rust-runtime/aws-smithy-async/src/test_util.rs +++ b/rust-runtime/aws-smithy-async/src/test_util.rs @@ -22,6 +22,21 @@ pub struct ManualTimeSource { log: Arc>>, } +#[cfg(feature = "test-util")] +impl ManualTimeSource { + /// Get the number of seconds since the UNIX Epoch as an f64. + /// + /// ## Panics + /// + /// This will panic if `self.now()` returns a time that's before the UNIX Epoch. + pub fn seconds_since_unix_epoch(&self) -> f64 { + self.now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs_f64() + } +} + impl TimeSource for ManualTimeSource { fn now(&self) -> SystemTime { self.start_time + self.log.lock().unwrap().iter().sum::() diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index 5983d09c8069d0cc879c0b4806446e653dbddf17..202e410827a1813a4443a67b7a4ab26f4ede8438 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -962,24 +962,20 @@ mod deferred_signer_tests { impl SignMessage for TestSigner { fn sign( &mut self, - message: crate::frame::Message, - ) -> Result { + message: Message, + ) -> Result { self.call_num += 1; Ok(message.add_header(Header::new("call_num", HeaderValue::Int32(self.call_num)))) } - fn sign_empty( - &mut self, - ) -> Option> { + fn sign_empty(&mut self) -> Option> { None } } let (mut signer, sender) = 
check_send_sync(DeferredSigner::new()); - sender - .send(Box::new(TestSigner::default())) - .expect("success"); + sender.send(Box::::default()).expect("success"); let message = signer.sign(Message::new(Bytes::new())).expect("success"); assert_eq!(1, message.headers()[0].value().as_int32().unwrap()); diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs index 25443d86d238c138a4c5f4346fa6b7ea0e57cfcb..dd918faba02ff46c40e9d3e364f7e6a80f6c7648 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/interceptors.rs @@ -15,7 +15,6 @@ use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreAppend, StoreReplac use aws_smithy_types::error::display::DisplayErrorContext; use context::{Error, Input, Output}; use std::fmt; -use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; use std::ops::Deref; use std::sync::Arc; @@ -54,7 +53,7 @@ macro_rules! interceptor_trait_fn { /// of the SDK ’s request execution pipeline. Hooks are either "read" hooks, which make it possible /// to read in-flight request or response messages, or "read/write" hooks, which make it possible /// to modify in-flight request or output messages. 
-pub trait Interceptor: std::fmt::Debug { +pub trait Interceptor: fmt::Debug { interceptor_trait_fn!( read_before_execution, BeforeSerializationInterceptorContextRef, @@ -584,8 +583,8 @@ pub struct SharedInterceptor { check_enabled: Arc bool + Send + Sync>, } -impl Debug for SharedInterceptor { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { +impl fmt::Debug for SharedInterceptor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SharedInterceptor") .field("interceptor", &self.interceptor) .finish() @@ -966,7 +965,7 @@ mod tests { 2 ); interceptors - .read_before_transmit(&mut InterceptorContext::new(Input::new(5)), &mut cfg) + .read_before_transmit(&InterceptorContext::new(Input::new(5)), &mut cfg) .expect_err("interceptor returns error"); cfg.interceptor_state() .store_put(disable_interceptor::("test")); @@ -979,7 +978,7 @@ mod tests { ); // shouldn't error because interceptors won't run interceptors - .read_before_transmit(&mut InterceptorContext::new(Input::new(5)), &mut cfg) + .read_before_transmit(&InterceptorContext::new(Input::new(5)), &mut cfg) .expect("interceptor is now disabled"); } } diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/request_attempts.rs b/rust-runtime/aws-smithy-runtime-api/src/client/request_attempts.rs index 1eb115d480a01c9d8365a1b2a8e112a2f7693f49..440019c6603797c19977fb696f425f40398c1dd4 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/request_attempts.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/request_attempts.rs @@ -7,26 +7,26 @@ use aws_smithy_types::config_bag::{Storable, StoreReplace}; #[derive(Debug, Clone, Copy)] pub struct RequestAttempts { - attempts: usize, + attempts: u32, } impl RequestAttempts { #[cfg(any(feature = "test-util", test))] - pub fn new(attempts: usize) -> Self { + pub fn new(attempts: u32) -> Self { Self { attempts } } - pub fn attempts(&self) -> usize { + pub fn attempts(&self) -> u32 { self.attempts } } -impl Storable for 
RequestAttempts { - type Storer = StoreReplace; -} - -impl From for RequestAttempts { - fn from(attempts: usize) -> Self { +impl From for RequestAttempts { + fn from(attempts: u32) -> Self { Self { attempts } } } + +impl Storable for RequestAttempts { + type Storer = StoreReplace; +} diff --git a/rust-runtime/aws-smithy-runtime/Cargo.toml b/rust-runtime/aws-smithy-runtime/Cargo.toml index 281e4245f9affbcbb947e5a84594b20438ad3683..d63a748b5cfcaa67f90567bcfc49756aa0e637be 100644 --- a/rust-runtime/aws-smithy-runtime/Cargo.toml +++ b/rust-runtime/aws-smithy-runtime/Cargo.toml @@ -28,6 +28,7 @@ pin-utils = "0.1.0" tokio = { version = "1.25", features = [] } tracing = "0.1.37" fastrand = "1.4" +once_cell = "1.18.0" [dev-dependencies] approx = "0.5.1" diff --git a/rust-runtime/aws-smithy-runtime/src/client/connections/test_connection.rs b/rust-runtime/aws-smithy-runtime/src/client/connections/test_connection.rs index 8c00beae6c4e238aaaaa9a384fe44c7d56e7b2ae..7db7ba541ce49c4c9bba779fdb443a1ca45375e0 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/connections/test_connection.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/connections/test_connection.rs @@ -5,7 +5,7 @@ //! Module with client connectors useful for testing. 
-use aws_smithy_async::rt::sleep::AsyncSleep; +use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_http::body::SdkBody; use aws_smithy_http::result::ConnectorError; use aws_smithy_protocol_test::{assert_ok, validate_body, MediaType}; @@ -191,16 +191,16 @@ impl ValidateRequest { pub struct TestConnection { data: Arc>, requests: Arc>>, - sleep_impl: Arc, + sleep_impl: SharedAsyncSleep, } impl TestConnection { - pub fn new(mut data: ConnectionEvents, sleep_impl: Arc) -> Self { + pub fn new(mut data: ConnectionEvents, sleep_impl: impl Into) -> Self { data.reverse(); TestConnection { data: Arc::new(Mutex::new(data)), requests: Default::default(), - sleep_impl, + sleep_impl: sleep_impl.into(), } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index c38b3820455566d2ee20cd04ab0d12a5768d1f52..bd9ae34d5f4e04bb7e39e9e37a19ed0cac44b7c3 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -28,7 +28,7 @@ use aws_smithy_runtime_api::client::retries::ShouldAttempt; use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugins; use aws_smithy_types::config_bag::ConfigBag; use std::mem; -use tracing::{debug, debug_span, instrument, Instrument}; +use tracing::{debug, debug_span, instrument, trace, Instrument}; mod auth; /// Defines types that implement a trait for endpoint resolution @@ -174,7 +174,7 @@ async fn try_op( .unwrap_or(Ok(ShouldAttempt::Yes)); match should_attempt { // Yes, let's make a request - Ok(ShouldAttempt::Yes) => debug!("retry strategy has OK'd initial request"), + Ok(ShouldAttempt::Yes) => debug!("retry strategy has OKed initial request"), // No, this request shouldn't be sent Ok(ShouldAttempt::No) => { let err: BoxError = "the retry strategy indicates that an initial request shouldn't be made, but it didn't specify why".into(); @@ -182,15 +182,20 @@ async fn 
try_op( } // No, we shouldn't make a request because... Err(err) => halt!([ctx] => OrchestratorError::other(err)), - Ok(ShouldAttempt::YesAfterDelay(_)) => { - unreachable!("Delaying the initial request is currently unsupported. If this feature is important to you, please file an issue in GitHub.") + Ok(ShouldAttempt::YesAfterDelay(delay)) => { + let sleep_impl = halt_on_err!([ctx] => cfg.sleep_impl().ok_or(OrchestratorError::other( + "the retry strategy requested a delay before sending the initial request, but no 'async sleep' implementation was set" + ))); + debug!("retry strategy has OKed initial request after a {delay:?} delay"); + sleep_impl.sleep(delay).await; } } // Save a request checkpoint before we make the request. This will allow us to "rewind" // the request in the case of retry attempts. ctx.save_checkpoint(); - for i in 1usize.. { + let mut retry_delay = None; + for i in 1u32.. { debug!("beginning attempt #{i}"); // Break from the loop if we can't rewind the request's state. This will always succeed the // first time, but will fail on subsequent iterations if the request body wasn't retryable. 
@@ -203,6 +208,11 @@ async fn try_op( .store_put::(i.into()); let attempt_timeout_config = cfg.maybe_timeout_config(TimeoutKind::OperationAttempt); let maybe_timeout = async { + // We must await this here or else timeouts won't work as expected + if let Some(delay) = retry_delay.take() { + delay.await; + } + try_attempt(ctx, cfg, interceptors, stop_point).await; finally_attempt(ctx, cfg, interceptors).await; Result::<_, SdkError>::Ok(()) @@ -229,14 +239,14 @@ async fn try_op( ShouldAttempt::Yes => continue, // No, this request shouldn't be retried ShouldAttempt::No => { - debug!("this error is not retryable, exiting attempt loop"); + debug!("a retry is either unnecessary or not possible, exiting attempt loop"); break; } ShouldAttempt::YesAfterDelay(delay) => { let sleep_impl = halt_on_err!([ctx] => cfg.sleep_impl().ok_or(OrchestratorError::other( - "the retry strategy requested a delay before sending the next request, but no 'async sleep' implementation was set" + "the retry strategy requested a delay before sending the retry request, but no 'async sleep' implementation was set" ))); - sleep_impl.sleep(delay).await; + retry_delay = Some(sleep_impl.sleep(delay)); continue; } } @@ -278,6 +288,7 @@ async fn try_attempt( } }) }); + trace!(response = ?call_result, "received response from service"); ctx.set_response(call_result); ctx.enter_before_deserialization_phase(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries.rs b/rust-runtime/aws-smithy-runtime/src/client/retries.rs index 9c41e5366a233d53b912924cecc6cd20a4a7d066..893c5f0163e7c2c4f84eaf0aa54e3ec5b1b6189d 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries.rs @@ -9,5 +9,35 @@ pub mod strategy; mod client_rate_limiter; mod token_bucket; -pub use client_rate_limiter::ClientRateLimiterRuntimePlugin; -pub use token_bucket::TokenBucketRuntimePlugin; +use aws_smithy_types::config_bag::{Storable, StoreReplace}; +pub use 
client_rate_limiter::{ClientRateLimiter, ClientRateLimiterRuntimePlugin}; +use std::fmt; +pub use token_bucket::{TokenBucket, TokenBucketRuntimePlugin}; + +#[doc(hidden)] +pub use client_rate_limiter::ClientRateLimiterPartition; +#[doc(hidden)] +pub use token_bucket::TokenBucketPartition; + +#[doc(hidden)] +#[non_exhaustive] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct RetryPartition { + inner: &'static str, +} + +impl RetryPartition { + pub fn new(name: &'static str) -> Self { + Self { inner: name } + } +} + +impl fmt::Display for RetryPartition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl Storable for RetryPartition { + type Storer = StoreReplace; +} diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs index 6b41f0fddd9b9f1d9c7a1c150c460d2138bfd9d2..69f36272ab07411cc2045c1111e218d0076e96c4 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/client_rate_limiter.rs @@ -6,43 +6,57 @@ //! A rate limiter for controlling the rate at which AWS requests are made. The rate changes based //! on the number of throttling errors encountered. -// TODO(enableNewSmithyRuntimeLaunch): Zelda will integrate this rate limiter into the retry policy in a separate PR. 
#![allow(dead_code)] -use aws_smithy_runtime_api::box_error::BoxError; -use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors; +use crate::client::retries::RetryPartition; use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; use aws_smithy_runtime_api::{builder, builder_methods, builder_struct}; -use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer, Storable, StoreReplace}; +use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, StoreReplace}; use std::sync::{Arc, Mutex}; -use std::time::SystemTime; +use std::time::Duration; +use tracing::debug; /// A [RuntimePlugin] to provide a client rate limiter, usable by a retry strategy. #[non_exhaustive] #[derive(Debug)] pub struct ClientRateLimiterRuntimePlugin { - _rate_limiter: Arc>, + rate_limiter: ClientRateLimiter, } impl ClientRateLimiterRuntimePlugin { - pub fn new(cfg: &ConfigBag) -> Self { + pub fn new(seconds_since_unix_epoch: f64) -> Self { Self { - _rate_limiter: Arc::new(Mutex::new(ClientRateLimiter::new(cfg))), + rate_limiter: ClientRateLimiter::new(seconds_since_unix_epoch), } } } impl RuntimePlugin for ClientRateLimiterRuntimePlugin { fn config(&self) -> Option { - let cfg = Layer::new("client rate limiter"); - // TODO(enableNewSmithyRuntimeLaunch) Move the Arc/Mutex inside the rate limiter so that it - // be both storable and cloneable. 
- // cfg.store_put(self.rate_limiter.clone()); + let mut cfg = Layer::new("client rate limiter"); + cfg.store_put(self.rate_limiter.clone()); Some(cfg.freeze()) } } +#[doc(hidden)] +#[non_exhaustive] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct ClientRateLimiterPartition { + retry_partition: RetryPartition, +} + +impl ClientRateLimiterPartition { + pub fn new(retry_partition: RetryPartition) -> Self { + Self { retry_partition } + } +} + +const RETRY_COST: f64 = 5.0; +const RETRY_TIMEOUT_COST: f64 = RETRY_COST * 2.0; +const INITIAL_REQUEST_COST: f64 = 1.0; + const MIN_FILL_RATE: f64 = 0.5; const MIN_CAPACITY: f64 = 1.0; const SMOOTH: f64 = 0.8; @@ -52,32 +66,40 @@ const BETA: f64 = 0.7; const SCALE_CONSTANT: f64 = 0.4; #[derive(Clone, Debug)] -pub(crate) struct ClientRateLimiter { +pub struct ClientRateLimiter { + inner: Arc>, +} + +#[derive(Debug)] +pub(crate) struct Inner { /// The rate at which token are replenished. - token_refill_rate: f64, + fill_rate: f64, /// The maximum capacity allowed in the token bucket. - maximum_bucket_capacity: f64, + max_capacity: f64, /// The current capacity of the token bucket. - /// The minimum this can be is 1.0 - current_bucket_capacity: f64, + current_capacity: f64, /// The last time the token bucket was refilled. - time_of_last_refill: Option, + last_timestamp: Option, + /// Boolean indicating if the token bucket is enabled. + /// The token bucket is initially disabled. + /// When a throttling error is encountered it is enabled. + enabled: bool, /// The smoothed rate which tokens are being retrieved. - tokens_retrieved_per_second: f64, + measured_tx_rate: f64, /// The last half second time bucket used. - previous_time_bucket: f64, + last_tx_rate_bucket: f64, /// The number of requests seen within the current time bucket. request_count: u64, - /// Boolean indicating if the token bucket is enabled. - /// The token bucket is initially disabled. - /// When a throttling error is encountered it is enabled. 
- enable_throttling: bool, /// The maximum rate when the client was last throttled. - tokens_retrieved_per_second_at_time_of_last_throttle: f64, + last_max_rate: f64, /// The last time when the client was throttled. time_of_last_throttle: f64, - time_window: f64, - calculated_rate: f64, +} + +pub(crate) enum RequestReason { + Retry, + RetryTimeout, + InitialRequest, } impl Storable for ClientRateLimiter { @@ -85,10 +107,11 @@ impl Storable for ClientRateLimiter { } impl ClientRateLimiter { - pub(crate) fn new(cfg: &ConfigBag) -> Self { + pub fn new(seconds_since_unix_epoch: f64) -> Self { Self::builder() - .time_of_last_throttle(get_unix_timestamp(cfg)) - .previous_time_bucket(get_unix_timestamp(cfg).floor()) + .tokens_retrieved_per_second(MIN_FILL_RATE) + .time_of_last_throttle(seconds_since_unix_epoch) + .previous_time_bucket(seconds_since_unix_epoch.floor()) .build() } @@ -96,108 +119,139 @@ impl ClientRateLimiter { Builder::new() } - /// If this function returns `Ok(())`, you're OK to send a request. If it returns an error, - /// then you should not send a request; You've sent quite enough already. 
pub(crate) fn acquire_permission_to_send_a_request( - &mut self, + &self, seconds_since_unix_epoch: f64, - amount: f64, - ) -> Result<(), BoxError> { - if !self.enable_throttling { + kind: RequestReason, + ) -> Result<(), Duration> { + let mut it = self.inner.lock().unwrap(); + + if !it.enabled { // return early if we haven't encountered a throttling error yet return Ok(()); } + let amount = match kind { + RequestReason::Retry => RETRY_COST, + RequestReason::RetryTimeout => RETRY_TIMEOUT_COST, + RequestReason::InitialRequest => INITIAL_REQUEST_COST, + }; + + it.refill(seconds_since_unix_epoch); + + let res = if amount > it.current_capacity { + let sleep_time = (amount - it.current_capacity) / it.fill_rate; + debug!( + amount, + it.current_capacity, + it.fill_rate, + sleep_time, + "client rate limiter delayed a request" + ); - self.refill(seconds_since_unix_epoch); - - if self.current_bucket_capacity < amount { - Err(BoxError::from("the client rate limiter is out of tokens")) + Err(Duration::from_secs_f64(sleep_time)) } else { - self.current_bucket_capacity -= amount; Ok(()) - } + }; + + it.current_capacity -= amount; + res } pub(crate) fn update_rate_limiter( - &mut self, + &self, seconds_since_unix_epoch: f64, is_throttling_error: bool, ) { - self.update_tokens_retrieved_per_second(seconds_since_unix_epoch); + let mut it = self.inner.lock().unwrap(); + it.update_tokens_retrieved_per_second(seconds_since_unix_epoch); + let calculated_rate; if is_throttling_error { - let rate_to_use = if self.enable_throttling { - f64::min(self.tokens_retrieved_per_second, self.token_refill_rate) + let rate_to_use = if it.enabled { + f64::min(it.measured_tx_rate, it.fill_rate) } else { - self.tokens_retrieved_per_second + it.measured_tx_rate }; // The fill_rate is from the token bucket - self.tokens_retrieved_per_second_at_time_of_last_throttle = rate_to_use; - self.calculate_time_window(); - self.time_of_last_throttle = seconds_since_unix_epoch; - self.calculated_rate = 
cubic_throttle(rate_to_use); - self.enable_token_bucket(); + it.last_max_rate = rate_to_use; + it.calculate_time_window(); + it.time_of_last_throttle = seconds_since_unix_epoch; + calculated_rate = cubic_throttle(rate_to_use); + it.enable_token_bucket(); } else { - self.calculate_time_window(); - self.calculated_rate = self.cubic_success(seconds_since_unix_epoch); + it.calculate_time_window(); + calculated_rate = it.cubic_success(seconds_since_unix_epoch); } - let new_rate = f64::min(self.calculated_rate, 2.0 * self.tokens_retrieved_per_second); - self.update_bucket_refill_rate(seconds_since_unix_epoch, new_rate); + let new_rate = f64::min(calculated_rate, 2.0 * it.measured_tx_rate); + it.update_bucket_refill_rate(seconds_since_unix_epoch, new_rate); } +} +impl Inner { fn refill(&mut self, seconds_since_unix_epoch: f64) { - if let Some(last_timestamp) = self.time_of_last_refill { - let fill_amount = (seconds_since_unix_epoch - last_timestamp) * self.token_refill_rate; - self.current_bucket_capacity = f64::min( - self.maximum_bucket_capacity, - self.current_bucket_capacity + fill_amount, + if let Some(last_timestamp) = self.last_timestamp { + let fill_amount = (seconds_since_unix_epoch - last_timestamp) * self.fill_rate; + self.current_capacity = + f64::min(self.max_capacity, self.current_capacity + fill_amount); + debug!( + fill_amount, + self.current_capacity, self.max_capacity, "refilling client rate limiter tokens" ); } - self.time_of_last_refill = Some(seconds_since_unix_epoch); + self.last_timestamp = Some(seconds_since_unix_epoch); } fn update_bucket_refill_rate(&mut self, seconds_since_unix_epoch: f64, new_fill_rate: f64) { // Refill based on our current rate before we update to the new fill rate. 
self.refill(seconds_since_unix_epoch); - self.token_refill_rate = f64::max(new_fill_rate, MIN_FILL_RATE); - self.maximum_bucket_capacity = f64::max(new_fill_rate, MIN_CAPACITY); + self.fill_rate = f64::max(new_fill_rate, MIN_FILL_RATE); + self.max_capacity = f64::max(new_fill_rate, MIN_CAPACITY); + + debug!( + fill_rate = self.fill_rate, + max_capacity = self.max_capacity, + current_capacity = self.current_capacity, + measured_tx_rate = self.measured_tx_rate, + "client rate limiter state has been updated" + ); + // When we scale down we can't have a current capacity that exceeds our max_capacity. - self.current_bucket_capacity = - f64::min(self.current_bucket_capacity, self.maximum_bucket_capacity); + self.current_capacity = f64::min(self.current_capacity, self.max_capacity); } fn enable_token_bucket(&mut self) { - self.enable_throttling = true; + // If throttling wasn't already enabled, note that we're now enabling it. + if !self.enabled { + debug!("client rate limiting has been enabled"); + } + self.enabled = true; } fn update_tokens_retrieved_per_second(&mut self, seconds_since_unix_epoch: f64) { let next_time_bucket = (seconds_since_unix_epoch * 2.0).floor() / 2.0; self.request_count += 1; - if next_time_bucket > self.previous_time_bucket { + if next_time_bucket > self.last_tx_rate_bucket { let current_rate = - self.request_count as f64 / (next_time_bucket - self.previous_time_bucket); - self.tokens_retrieved_per_second = - current_rate * SMOOTH + self.tokens_retrieved_per_second * (1.0 - SMOOTH); + self.request_count as f64 / (next_time_bucket - self.last_tx_rate_bucket); + self.measured_tx_rate = current_rate * SMOOTH + self.measured_tx_rate * (1.0 - SMOOTH); self.request_count = 0; - self.previous_time_bucket = next_time_bucket; + self.last_tx_rate_bucket = next_time_bucket; } } - fn calculate_time_window(&mut self) { - // This is broken out into a separate calculation because it only - // gets updated when 
@tokens_retrieved_per_second_at_time_of_last_throttle() changes so it can be cached. - let base = (self.tokens_retrieved_per_second_at_time_of_last_throttle * (1.0 - BETA)) - / SCALE_CONSTANT; - self.time_window = base.powf(1.0 / 3.0); + fn calculate_time_window(&self) -> f64 { + let base = (self.last_max_rate * (1.0 - BETA)) / SCALE_CONSTANT; + base.powf(1.0 / 3.0) } fn cubic_success(&self, seconds_since_unix_epoch: f64) -> f64 { - let dt = seconds_since_unix_epoch - self.time_of_last_throttle - self.time_window; - (SCALE_CONSTANT * dt.powi(3)) + self.tokens_retrieved_per_second_at_time_of_last_throttle + let dt = + seconds_since_unix_epoch - self.time_of_last_throttle - self.calculate_time_window(); + (SCALE_CONSTANT * dt.powi(3)) + self.last_max_rate } } @@ -205,15 +259,6 @@ fn cubic_throttle(rate_to_use: f64) -> f64 { rate_to_use * BETA } -fn get_unix_timestamp(cfg: &ConfigBag) -> f64 { - let request_time = cfg.request_time().unwrap(); - request_time - .now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs_f64() -} - builder!( set_token_refill_rate, token_refill_rate, f64, "The rate at which token are replenished.", set_maximum_bucket_capacity, maximum_bucket_capacity, f64, "The maximum capacity allowed in the token bucket.", @@ -224,35 +269,34 @@ builder!( set_request_count, request_count, u64, "The number of requests seen within the current time bucket.", set_enable_throttling, enable_throttling, bool, "Boolean indicating if the token bucket is enabled. The token bucket is initially disabled. 
When a throttling error is encountered it is enabled.", set_tokens_retrieved_per_second_at_time_of_last_throttle, tokens_retrieved_per_second_at_time_of_last_throttle, f64, "The maximum rate when the client was last throttled.", - set_time_of_last_throttle, time_of_last_throttle, f64, "The last time when the client was throttled.", - set_time_window, time_window, f64, "The time window used to calculate the cubic success rate.", - set_calculated_rate, calculated_rate, f64, "The calculated rate used to update the sending rate." + set_time_of_last_throttle, time_of_last_throttle, f64, "The last time when the client was throttled." ); impl Builder { fn build(self) -> ClientRateLimiter { ClientRateLimiter { - token_refill_rate: self.token_refill_rate.unwrap_or_default(), - maximum_bucket_capacity: self.maximum_bucket_capacity.unwrap_or(f64::MAX), - current_bucket_capacity: self.current_bucket_capacity.unwrap_or_default(), - time_of_last_refill: self.time_of_last_refill, - enable_throttling: self.enable_throttling.unwrap_or_default(), - tokens_retrieved_per_second: self.tokens_retrieved_per_second.unwrap_or_default(), - previous_time_bucket: self.previous_time_bucket.unwrap_or_default(), - request_count: self.request_count.unwrap_or_default(), - tokens_retrieved_per_second_at_time_of_last_throttle: self - .tokens_retrieved_per_second_at_time_of_last_throttle - .unwrap_or_default(), - time_of_last_throttle: self.time_of_last_throttle.unwrap_or_default(), - time_window: self.time_window.unwrap_or_default(), - calculated_rate: self.calculated_rate.unwrap_or_default(), + inner: Arc::new(Mutex::new(Inner { + fill_rate: self.token_refill_rate.unwrap_or_default(), + max_capacity: self.maximum_bucket_capacity.unwrap_or(f64::MAX), + current_capacity: self.current_bucket_capacity.unwrap_or_default(), + last_timestamp: self.time_of_last_refill, + enabled: self.enable_throttling.unwrap_or_default(), + measured_tx_rate: self.tokens_retrieved_per_second.unwrap_or_default(), + 
last_tx_rate_bucket: self.previous_time_bucket.unwrap_or_default(), + request_count: self.request_count.unwrap_or_default(), + last_max_rate: self + .tokens_retrieved_per_second_at_time_of_last_throttle + .unwrap_or_default(), + time_of_last_throttle: self.time_of_last_throttle.unwrap_or_default(), + })), } } } #[cfg(test)] mod tests { - use super::{cubic_throttle, get_unix_timestamp, ClientRateLimiter}; + use super::{cubic_throttle, ClientRateLimiter}; + use crate::client::retries::client_rate_limiter::RequestReason; use approx::assert_relative_eq; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_async::test_util::instant_time_and_sleep; @@ -261,28 +305,21 @@ mod tests { use aws_smithy_types::config_bag::ConfigBag; use std::time::{Duration, SystemTime}; - #[test] - fn it_sets_the_time_window_correctly() { - let mut rate_limiter = ClientRateLimiter::builder() - .tokens_retrieved_per_second_at_time_of_last_throttle(10.0) - .build(); - - rate_limiter.calculate_time_window(); - assert_relative_eq!(rate_limiter.time_window, 1.9574338205844317); - } + const ONE_SECOND: Duration = Duration::from_secs(1); + const TWO_HUNDRED_MILLISECONDS: Duration = Duration::from_millis(200); #[test] fn should_match_beta_decrease() { let new_rate = cubic_throttle(10.0); assert_relative_eq!(new_rate, 7.0); - let mut rate_limiter = ClientRateLimiter::builder() + let rate_limiter = ClientRateLimiter::builder() .tokens_retrieved_per_second_at_time_of_last_throttle(10.0) .time_of_last_throttle(1.0) .build(); - rate_limiter.calculate_time_window(); - let new_rate = rate_limiter.cubic_success(1.0); + rate_limiter.inner.lock().unwrap().calculate_time_window(); + let new_rate = rate_limiter.inner.lock().unwrap().cubic_success(1.0); assert_relative_eq!(new_rate, 7.0); } @@ -294,19 +331,18 @@ mod tests { .set_request_time(SharedTimeSource::new(time_source)); cfg.interceptor_state() .set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl))); - let now = 
get_unix_timestamp(&cfg); - let mut rate_limiter = ClientRateLimiter::builder() - .previous_time_bucket((now).floor()) - .time_of_last_throttle(now) + let rate_limiter = ClientRateLimiter::builder() + .previous_time_bucket(0.0) + .time_of_last_throttle(0.0) .build(); assert!( - !rate_limiter.enable_throttling, + !rate_limiter.inner.lock().unwrap().enabled, "rate_limiter should be disabled by default" ); - rate_limiter.update_rate_limiter(now, true); + rate_limiter.update_rate_limiter(0.0, true); assert!( - rate_limiter.enable_throttling, + rate_limiter.inner.lock().unwrap().enabled, "rate_limiter should be enabled after throttling error" ); } @@ -320,9 +356,8 @@ mod tests { .set_request_time(SharedTimeSource::new(time_source)); cfg.interceptor_state() .set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone()))); - let now = get_unix_timestamp(&cfg); - let mut rate_limiter = ClientRateLimiter::builder() - .time_of_last_throttle(now) + let rate_limiter = ClientRateLimiter::builder() + .time_of_last_throttle(5.0) .tokens_retrieved_per_second_at_time_of_last_throttle(10.0) .build(); @@ -366,8 +401,12 @@ mod tests { // was implemented. 
See for yourself: // https://github.com/aws/aws-sdk-go-v2/blob/844ff45cdc76182229ad098c95bf3f5ab8c20e9f/aws/retry/adaptive_ratelimit_test.go#L97 for attempt in attempts { - rate_limiter.calculate_time_window(); - let calculated_rate = rate_limiter.cubic_success(attempt.seconds_since_unix_epoch); + rate_limiter.inner.lock().unwrap().calculate_time_window(); + let calculated_rate = rate_limiter + .inner + .lock() + .unwrap() + .cubic_success(attempt.seconds_since_unix_epoch); assert_relative_eq!(attempt.expected_calculated_rate, calculated_rate); } @@ -382,10 +421,9 @@ mod tests { .set_request_time(SharedTimeSource::new(time_source)); cfg.interceptor_state() .set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone()))); - let now = get_unix_timestamp(&cfg); - let mut rate_limiter = ClientRateLimiter::builder() + let rate_limiter = ClientRateLimiter::builder() .tokens_retrieved_per_second_at_time_of_last_throttle(10.0) - .time_of_last_throttle(now) + .time_of_last_throttle(5.0) .build(); struct Attempt { @@ -442,13 +480,14 @@ mod tests { // https://github.com/aws/aws-sdk-go-v2/blob/844ff45cdc76182229ad098c95bf3f5ab8c20e9f/aws/retry/adaptive_ratelimit_test.go#L97 let mut calculated_rate = 0.0; for attempt in attempts { - rate_limiter.calculate_time_window(); + let mut inner = rate_limiter.inner.lock().unwrap(); + inner.calculate_time_window(); if attempt.throttled { calculated_rate = cubic_throttle(calculated_rate); - rate_limiter.time_of_last_throttle = attempt.seconds_since_unix_epoch; - rate_limiter.tokens_retrieved_per_second_at_time_of_last_throttle = calculated_rate; + inner.time_of_last_throttle = attempt.seconds_since_unix_epoch; + inner.last_max_rate = calculated_rate; } else { - calculated_rate = rate_limiter.cubic_success(attempt.seconds_since_unix_epoch); + calculated_rate = inner.cubic_success(attempt.seconds_since_unix_epoch); }; assert_relative_eq!(attempt.expected_calculated_rate, calculated_rate); @@ -463,7 +502,7 @@ mod tests { 
.set_request_time(SharedTimeSource::new(time_source)); cfg.interceptor_state() .set_sleep_impl(Some(SharedAsyncSleep::new(sleep_impl.clone()))); - let mut rate_limiter = ClientRateLimiter::builder().build(); + let rate_limiter = ClientRateLimiter::builder().build(); struct Attempt { throttled: bool, @@ -577,9 +616,8 @@ mod tests { }, ]; - let two_hundred_milliseconds = Duration::from_millis(200); for attempt in attempts { - sleep_impl.sleep(two_hundred_milliseconds).await; + sleep_impl.sleep(TWO_HUNDRED_MILLISECONDS).await; assert_eq!( attempt.seconds_since_unix_epoch, sleep_impl.total_duration().as_secs_f64() @@ -588,12 +626,53 @@ mod tests { rate_limiter.update_rate_limiter(attempt.seconds_since_unix_epoch, attempt.throttled); assert_relative_eq!( attempt.expected_tokens_retrieved_per_second, - rate_limiter.tokens_retrieved_per_second + rate_limiter.inner.lock().unwrap().measured_tx_rate ); assert_relative_eq!( attempt.expected_token_refill_rate, - rate_limiter.token_refill_rate + rate_limiter.inner.lock().unwrap().fill_rate ); } } + + // This test is only testing that we don't fail basic math and panic. It does include an + // element of randomness, but no duration between >= 0.0s and <= 1.0s will ever cause a panic. + // + // Because the cost of sending an individual request is 1.0, and because the minimum capacity is + // also 1.0, we will never encounter a situation where we run out of tokens. 
+ #[tokio::test] + async fn test_when_throttling_is_enabled_requests_can_still_be_sent() { + let (time_source, sleep_impl) = instant_time_and_sleep(SystemTime::UNIX_EPOCH); + let crl = ClientRateLimiter::builder() + .time_of_last_throttle(0.0) + .previous_time_bucket(0.0) + .build(); + + // Start by recording a throttling error + crl.update_rate_limiter(0.0, true); + + for _i in 0..100 { + // advance time by a random amount (up to 1s) each iteration + let duration = Duration::from_secs_f64(fastrand::f64()); + sleep_impl.sleep(duration).await; + if let Err(delay) = crl.acquire_permission_to_send_a_request( + time_source.seconds_since_unix_epoch(), + RequestReason::InitialRequest, + ) { + sleep_impl.sleep(delay).await; + } + + // Assume all further requests succeed on the first try + crl.update_rate_limiter(time_source.seconds_since_unix_epoch(), false); + } + + let inner = crl.inner.lock().unwrap(); + assert!(inner.enabled, "the rate limiter should still be enabled"); + // Assert that the rate limiter respects the passage of time. 
+ assert_relative_eq!( + inner.last_timestamp.unwrap(), + sleep_impl.total_duration().as_secs_f64(), + max_relative = 0.0001 + ); + } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs index 81437cead3f3307c57b8e7c2d530b4d569efe3b6..088e7bfd8c59123cc5dc9e71c7796872ed8fe8fd 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/fixed_delay.rs @@ -56,7 +56,7 @@ impl RetryStrategy for FixedDelayRetryStrategy { let request_attempts = cfg .load::() .expect("at least one request attempt is made before any retry is attempted"); - if request_attempts.attempts() >= self.max_attempts as usize { + if request_attempts.attempts() >= self.max_attempts { tracing::trace!( attempts = request_attempts.attempts(), max_attempts = self.max_attempts, diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs index 50a83cae9b97b90fd84cda3a1a5dc97f77a56b8f..307297e0c5a88bebb20e058420010a45a110a0cd 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/strategy/standard.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use crate::client::retries::client_rate_limiter::{ClientRateLimiter, RequestReason}; use crate::client::retries::strategy::standard::ReleaseResult::{ APermitWasReleased, NoPermitWasReleased, }; @@ -15,20 +16,21 @@ use aws_smithy_runtime_api::client::retries::{ ClassifyRetry, RetryReason, RetryStrategy, ShouldAttempt, }; use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace}; -use aws_smithy_types::retry::RetryConfig; +use aws_smithy_types::retry::{ErrorKind, RetryConfig}; use std::sync::Mutex; -use std::time::Duration; +use std::time::{Duration, 
SystemTime}; use tokio::sync::OwnedSemaphorePermit; +use tracing::debug; // The initial attempt, plus three retries. -const DEFAULT_MAX_ATTEMPTS: usize = 4; +const DEFAULT_MAX_ATTEMPTS: u32 = 4; #[derive(Debug)] pub struct StandardRetryStrategy { // Retry settings base: fn() -> f64, initial_backoff: Duration, - max_attempts: usize, + max_attempts: u32, max_backoff: Duration, retry_permit: Mutex>, } @@ -39,9 +41,16 @@ impl Storable for StandardRetryStrategy { impl StandardRetryStrategy { pub fn new(retry_config: &RetryConfig) -> Self { + let base = if retry_config.use_static_exponential_base() { + || 1.0 + } else { + fastrand::f64 + }; // TODO(enableNewSmithyRuntimeLaunch) add support for `retry_config.reconnect_mode()` here or in the orchestrator flow. Self::default() - .with_max_attempts(retry_config.max_attempts() as usize) + .with_base(base) + .with_max_backoff(retry_config.max_backoff()) + .with_max_attempts(retry_config.max_attempts()) .with_initial_backoff(retry_config.initial_backoff()) } @@ -50,7 +59,7 @@ impl StandardRetryStrategy { self } - pub fn with_max_attempts(mut self, max_attempts: usize) -> Self { + pub fn with_max_attempts(mut self, max_attempts: u32) -> Self { self.max_attempts = max_attempts; self } @@ -84,6 +93,61 @@ impl StandardRetryStrategy { p.forget() } } + + fn calculate_backoff( + &self, + cfg: &ConfigBag, + retry_reason: Option<&RetryReason>, + ) -> Result { + let request_attempts = cfg + .load::() + .expect("at least one request attempt is made before any retry is attempted") + .attempts(); + let token_bucket = cfg.load::(); + + match retry_reason { + Some(RetryReason::Explicit(backoff)) => Ok(*backoff), + Some(RetryReason::Error(kind)) => { + update_rate_limiter_if_exists(cfg, *kind == ErrorKind::ThrottlingError); + if let Some(delay) = check_rate_limiter_for_delay(cfg, *kind) { + let delay = delay.min(self.max_backoff); + debug!("rate limiter has requested a {delay:?} delay before retrying"); + Ok(delay) + } else { + if let Some(tb) 
= token_bucket { + match tb.acquire(kind) { + Some(permit) => self.set_retry_permit(permit), + None => { + debug!("attempt #{request_attempts} failed with {kind:?}; However, no retry permits are available, so no retry will be attempted."); + return Err(ShouldAttempt::No); + } + } + } + + let backoff = calculate_exponential_backoff( + // Generate a random base multiplier to create jitter + (self.base)(), + // Get the backoff time multiplier in seconds (with fractional seconds) + self.initial_backoff.as_secs_f64(), + // `self.local.attempts` tracks number of requests made including the initial request + // The initial attempt shouldn't count towards backoff calculations so we subtract it + request_attempts - 1, + ); + Ok(Duration::from_secs_f64(backoff).min(self.max_backoff)) + } + } + Some(_) => unreachable!("RetryReason is non-exhaustive"), + None => { + update_rate_limiter_if_exists(cfg, false); + debug!( + attempts = request_attempts, + max_attempts = self.max_attempts, + "encountered unretryable error" + ); + Err(ShouldAttempt::No) + } + } + } } enum ReleaseResult { @@ -105,8 +169,19 @@ impl Default for StandardRetryStrategy { } impl RetryStrategy for StandardRetryStrategy { - fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result { - // The standard token bucket is only ever considered for retry requests. 
+ fn should_attempt_initial_request(&self, cfg: &ConfigBag) -> Result { + if let Some(crl) = cfg.load::() { + let seconds_since_unix_epoch = get_seconds_since_unix_epoch(cfg); + if let Err(delay) = crl.acquire_permission_to_send_a_request( + seconds_since_unix_epoch, + RequestReason::InitialRequest, + ) { + return Ok(ShouldAttempt::YesAfterDelay(delay)); + } + } else { + debug!("no client rate limiter configured, so no token is required for the initial request."); + } + Ok(ShouldAttempt::Yes) } @@ -121,7 +196,7 @@ impl RetryStrategy for StandardRetryStrategy { ); let token_bucket = cfg.load::(); if output_or_error.is_ok() { - tracing::debug!("request succeeded, no retry necessary"); + debug!("request succeeded, no retry necessary"); if let Some(tb) = token_bucket { // If this retry strategy is holding any permits, release them back to the bucket. if let NoPermitWasReleased = self.release_retry_permit() { @@ -132,6 +207,7 @@ impl RetryStrategy for StandardRetryStrategy { tb.regenerate_a_token(); } } + update_rate_limiter_if_exists(cfg, false); return Ok(ShouldAttempt::No); } @@ -142,7 +218,9 @@ impl RetryStrategy for StandardRetryStrategy { .expect("at least one request attempt is made before any retry is attempted") .attempts(); if request_attempts >= self.max_attempts { - tracing::trace!( + update_rate_limiter_if_exists(cfg, false); + + debug!( attempts = request_attempts, max_attempts = self.max_attempts, "not retrying because we are out of attempts" @@ -155,46 +233,12 @@ impl RetryStrategy for StandardRetryStrategy { let retry_reason = retry_classifiers.classify_retry(ctx); // Calculate the appropriate backoff time. - let backoff = match retry_reason { - Some(RetryReason::Explicit(dur)) => dur, - Some(RetryReason::Error(kind)) => { - // If a token bucket was set, and the RetryReason IS NOT explicit, attempt to acquire a retry permit. 
- if let Some(tb) = token_bucket { - match tb.acquire(&kind) { - Some(permit) => self.set_retry_permit(permit), - None => { - tracing::debug!( - "attempt #{request_attempts} failed with {kind:?}; \ - However, no retry permits are available, so no retry will be attempted.", - ); - return Ok(ShouldAttempt::No); - } - } - }; - - let backoff = calculate_exponential_backoff( - // Generate a random base multiplier to create jitter - (self.base)(), - // Get the backoff time multiplier in seconds (with fractional seconds) - self.initial_backoff.as_secs_f64(), - // `self.local.attempts` tracks number of requests made including the initial request - // The initial attempt shouldn't count towards backoff calculations so we subtract it - (request_attempts - 1) as u32, - ); - Duration::from_secs_f64(backoff).min(self.max_backoff) - } - Some(_) => unreachable!("RetryReason is non-exhaustive"), - None => { - tracing::debug!( - attempts = request_attempts, - max_attempts = self.max_attempts, - "encountered unretryable error" - ); - return Ok(ShouldAttempt::No); - } + let backoff = match self.calculate_backoff(cfg, retry_reason.as_ref()) { + Ok(value) => value, + // In some cases, backoff calculation will decide that we shouldn't retry at all. 
+ Err(value) => return Ok(value), }; - - tracing::debug!( + debug!( "attempt #{request_attempts} failed with {:?}; retrying after {:?}", retry_reason.expect("the match statement above ensures this is not None"), backoff @@ -204,18 +248,47 @@ impl RetryStrategy for StandardRetryStrategy { } } +fn update_rate_limiter_if_exists(cfg: &ConfigBag, is_throttling_error: bool) { + if let Some(crl) = cfg.load::() { + let seconds_since_unix_epoch = get_seconds_since_unix_epoch(cfg); + crl.update_rate_limiter(seconds_since_unix_epoch, is_throttling_error); + } +} + +fn check_rate_limiter_for_delay(cfg: &ConfigBag, kind: ErrorKind) -> Option { + if let Some(crl) = cfg.load::() { + let retry_reason = if kind == ErrorKind::ThrottlingError { + RequestReason::RetryTimeout + } else { + RequestReason::Retry + }; + if let Err(delay) = crl + .acquire_permission_to_send_a_request(get_seconds_since_unix_epoch(cfg), retry_reason) + { + return Some(delay); + } + } + + None +} + fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts: u32) -> f64 { base * initial_backoff * 2_u32.pow(retry_attempts) as f64 } +fn get_seconds_since_unix_epoch(cfg: &ConfigBag) -> f64 { + let request_time = cfg.request_time().unwrap(); + request_time + .now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs_f64() +} + #[cfg(test)] mod tests { use super::*; - use super::{calculate_exponential_backoff, ShouldAttempt, StandardRetryStrategy}; - use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors; - use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext; use aws_smithy_runtime_api::client::orchestrator::OrchestratorError; - use aws_smithy_runtime_api::client::request_attempts::RequestAttempts; use aws_smithy_runtime_api::client::retries::{ AlwaysRetry, ClassifyRetry, RetryClassifiers, RetryReason, RetryStrategy, }; @@ -243,7 +316,7 @@ mod tests { fn set_up_cfg_and_context( error_kind: ErrorKind, - current_request_attempts: usize, 
+ current_request_attempts: u32, ) -> (InterceptorContext, ConfigBag) { let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter()); ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter"))); @@ -492,7 +565,7 @@ mod tests { let (mut cfg, mut ctx) = setup_test(vec![RetryReason::Error(ErrorKind::TransientError)]); let strategy = StandardRetryStrategy::default() .with_base(|| 1.0) - .with_max_attempts(usize::MAX); + .with_max_attempts(u32::MAX); cfg.interceptor_state() .store_put(TokenBucket::new(PERMIT_COUNT)); let token_bucket = cfg.load::().unwrap().clone(); @@ -522,7 +595,7 @@ mod tests { // Replenish permits until we get back to `PERMIT_COUNT` while token_bucket.available_permits() < PERMIT_COUNT { if attempt > 23 { - panic!("This test should have completed by now (fillup)"); + panic!("This test should have completed by now (fill-up)"); } cfg.interceptor_state() diff --git a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs index df1510459aa500e79d0baa52f12e4888fa20ef9b..a7c95d2a51139ecff764c8b57a0d079046342068 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/retries/token_bucket.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use crate::client::retries::RetryPartition; use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin; use aws_smithy_types::config_bag::{FrozenLayer, Layer, Storable, StoreReplace}; use aws_smithy_types::retry::ErrorKind; @@ -34,13 +35,26 @@ impl RuntimePlugin for TokenBucketRuntimePlugin { } } +#[doc(hidden)] +#[non_exhaustive] +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct TokenBucketPartition { + retry_partition: RetryPartition, +} + +impl TokenBucketPartition { + pub fn new(retry_partition: RetryPartition) -> Self { + Self { retry_partition } + } +} + const DEFAULT_CAPACITY: usize = 500; const RETRY_COST: u32 = 5; 
const RETRY_TIMEOUT_COST: u32 = RETRY_COST * 2; const PERMIT_REGENERATION_AMOUNT: usize = 1; #[derive(Clone, Debug)] -pub(crate) struct TokenBucket { +pub struct TokenBucket { semaphore: Arc, max_permits: usize, timeout_retry_cost: u32, @@ -63,7 +77,7 @@ impl Default for TokenBucket { } impl TokenBucket { - pub(crate) fn new(initial_quota: usize) -> Self { + pub fn new(initial_quota: usize) -> Self { Self { semaphore: Arc::new(Semaphore::new(initial_quota)), max_permits: initial_quota, diff --git a/rust-runtime/aws-smithy-runtime/src/client/timeout.rs b/rust-runtime/aws-smithy-runtime/src/client/timeout.rs index cb8b0d0e90cfa3fb7c9dd3a4d20720c255143181..e1ba3aec636fef3214e08864449c3fd10437aa95 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/timeout.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/timeout.rs @@ -180,11 +180,16 @@ where #[cfg(test)] mod tests { - use super::*; + use crate::client::timeout::{MaybeTimeout, TimeoutKind}; use aws_smithy_async::assert_elapsed; use aws_smithy_async::future::never::Never; - use aws_smithy_async::rt::sleep::TokioSleep; - use aws_smithy_types::config_bag::Layer; + use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep, TokioSleep}; + use aws_smithy_http::result::SdkError; + use aws_smithy_runtime_api::client::config_bag_accessors::ConfigBagAccessors; + use aws_smithy_runtime_api::client::orchestrator::HttpResponse; + use aws_smithy_types::config_bag::{ConfigBag, Layer}; + use aws_smithy_types::timeout::TimeoutConfig; + use std::time::Duration; #[tokio::test] async fn test_no_timeout() { diff --git a/rust-runtime/aws-smithy-runtime/src/lib.rs b/rust-runtime/aws-smithy-runtime/src/lib.rs index 195fde2d067fa746c6f70e21d7d4d8d2ac62e3a4..2c2eb6a9f8efdc64a29bb8364e665c34a9a246be 100644 --- a/rust-runtime/aws-smithy-runtime/src/lib.rs +++ b/rust-runtime/aws-smithy-runtime/src/lib.rs @@ -11,3 +11,5 @@ )] pub mod client; + +pub mod static_partition_map; diff --git 
a/rust-runtime/aws-smithy-runtime/src/static_partition_map.rs b/rust-runtime/aws-smithy-runtime/src/static_partition_map.rs new file mode 100644 index 0000000000000000000000000000000000000000..10b0070ccfa2239600bcaf857f74885dec5791eb --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/src/static_partition_map.rs @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use once_cell::sync::OnceCell; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::{Mutex, MutexGuard}; + +/// A data structure for persisting and sharing state between multiple clients. +/// +/// Some state should be shared between multiple clients. For example, when creating multiple clients +/// for the same service, it's desirable to share a client rate limiter. This way, when one client +/// receives a throttling response, the other clients will be aware of it as well. +/// +/// Whether clients share state is dependent on their partition key `K`. Going back to the client +/// rate limiter example, `K` would be a struct containing the name of the service as well as the +/// client's configured region, since receiving throttling responses in `us-east-1` shouldn't +/// throttle requests to the same service made in other regions. +/// +/// Values stored in a `StaticPartitionMap` will be cloned whenever they are requested. Values must +/// be initialized before they can be retrieved, and the `StaticPartitionMap::get_or_init` method is +/// how you can ensure this. +/// +/// # Example +/// +/// ``` +///use std::sync::{Arc, Mutex}; +/// use aws_smithy_runtime::static_partition_map::StaticPartitionMap; +/// +/// // The shared state must be `Clone` and will be internally mutable. Deriving `Default` isn't +/// // necessary, but allows us to use the `StaticPartitionMap::get_or_init_default` method. 
+/// #[derive(Clone, Default)] +/// pub struct SomeSharedState { +/// inner: Arc> +/// } +/// +/// #[derive(Default)] +/// struct Inner { +/// // Some shared state... +/// } +/// +/// // `Clone`, `Hash`, and `Eq` are all required trait impls for partition keys +/// #[derive(Clone, Hash, PartialEq, Eq)] +/// pub struct SharedStatePartition { +/// region: String, +/// service_name: String, +/// } +/// +/// impl SharedStatePartition { +/// pub fn new(region: impl Into, service_name: impl Into) -> Self { +/// Self { region: region.into(), service_name: service_name.into() } +/// } +/// } +/// +/// static SOME_SHARED_STATE: StaticPartitionMap = StaticPartitionMap::new(); +/// +/// struct Client { +/// shared_state: SomeSharedState, +/// } +/// +/// impl Client { +/// pub fn new() -> Self { +/// let key = SharedStatePartition::new("us-east-1", "example_service_20230628"); +/// Self { +/// // If the stored value implements `Default`, you can call the +/// // `StaticPartitionMap::get_or_init_default` convenience method. +/// shared_state: SOME_SHARED_STATE.get_or_init_default(key), +/// } +/// } +/// } +/// ``` +#[derive(Debug, Default)] +pub struct StaticPartitionMap { + inner: OnceCell>>, +} + +impl StaticPartitionMap { + pub const fn new() -> Self { + Self { + inner: OnceCell::new(), + } + } +} + +impl StaticPartitionMap +where + K: Eq + Hash, +{ + fn get_or_init_inner(&self) -> MutexGuard<'_, HashMap> { + self.inner + // At the very least, we'll always be storing the default state. 
+ .get_or_init(|| Mutex::new(HashMap::with_capacity(1))) + .lock() + .unwrap() + } +} + +impl StaticPartitionMap +where + K: Eq + Hash, + V: Clone, +{ + #[must_use] + pub fn get(&self, partition_key: K) -> Option { + self.get_or_init_inner().get(&partition_key).cloned() + } + + #[must_use] + pub fn get_or_init(&self, partition_key: K, f: F) -> V + where + F: FnOnce() -> V, + { + let mut inner = self.get_or_init_inner(); + let v = inner.entry(partition_key).or_insert_with(f); + v.clone() + } +} + +impl StaticPartitionMap +where + K: Eq + Hash, + V: Clone + Default, +{ + #[must_use] + pub fn get_or_init_default(&self, partition_key: K) -> V { + self.get_or_init(partition_key, V::default) + } +} + +#[cfg(test)] +mod tests { + use super::StaticPartitionMap; + + #[test] + fn test_keyed_partition_returns_same_value_for_same_key() { + let kp = StaticPartitionMap::new(); + let _ = kp.get_or_init("A", || "A".to_owned()); + let actual = kp.get_or_init("A", || "B".to_owned()); + let expected = "A".to_owned(); + assert_eq!(expected, actual); + } + + #[test] + fn test_keyed_partition_returns_different_value_for_different_key() { + let kp = StaticPartitionMap::new(); + let _ = kp.get_or_init("A", || "A".to_owned()); + let actual = kp.get_or_init("B", || "B".to_owned()); + + let expected = "B".to_owned(); + assert_eq!(expected, actual); + + let actual = kp.get("A").unwrap(); + let expected = "A".to_owned(); + assert_eq!(expected, actual); + } +} diff --git a/rust-runtime/aws-smithy-types/src/retry.rs b/rust-runtime/aws-smithy-types/src/retry.rs index 1db024a411f258d5ca7885c794ad513e7903a42b..a5144178142bb87be4d3ce8e381ff44f3dba72b3 100644 --- a/rust-runtime/aws-smithy-types/src/retry.rs +++ b/rust-runtime/aws-smithy-types/src/retry.rs @@ -99,12 +99,12 @@ impl FromStr for RetryMode { fn from_str(string: &str) -> Result { let string = string.trim(); + // eq_ignore_ascii_case is OK here because the only strings we need to check for are ASCII if 
string.eq_ignore_ascii_case("standard") { Ok(RetryMode::Standard) - // TODO(https://github.com/awslabs/aws-sdk-rust/issues/247): adaptive retries - // } else if string.eq_ignore_ascii_case("adaptive") { - // Ok(RetryMode::Adaptive) + } else if string.eq_ignore_ascii_case("adaptive") { + Ok(RetryMode::Adaptive) } else { Err(RetryModeParseError::new(string)) } @@ -264,6 +264,7 @@ impl RetryConfigBuilder { .reconnect_mode .unwrap_or(ReconnectMode::ReconnectOnTransientError), max_backoff: self.max_backoff.unwrap_or_else(|| Duration::from_secs(20)), + use_static_exponential_base: false, } } } @@ -277,6 +278,7 @@ pub struct RetryConfig { initial_backoff: Duration, max_backoff: Duration, reconnect_mode: ReconnectMode, + use_static_exponential_base: bool, } impl Storable for RetryConfig { @@ -308,6 +310,19 @@ impl RetryConfig { initial_backoff: Duration::from_secs(1), reconnect_mode: ReconnectMode::ReconnectOnTransientError, max_backoff: Duration::from_secs(20), + use_static_exponential_base: false, + } + } + + /// Creates a default `RetryConfig` with `RetryMode::Adaptive` and max attempts of three. + pub fn adaptive() -> Self { + Self { + mode: RetryMode::Adaptive, + max_attempts: 3, + initial_backoff: Duration::from_secs(1), + reconnect_mode: ReconnectMode::ReconnectOnTransientError, + max_backoff: Duration::from_secs(20), + use_static_exponential_base: false, } } @@ -363,6 +378,20 @@ impl RetryConfig { self } + /// Hint to the retry strategy whether to use a static exponential base. + /// + /// When a retry strategy uses exponential backoff, it calculates a random base. This causes the + /// retry delay to be slightly random, and helps prevent "thundering herd" scenarios. However, + /// it's often useful during testing to know exactly how long the delay will be. + /// + /// Therefore, if you're writing a test and asserting an expected retry delay, + /// set this to `true`. 
+ #[cfg(feature = "test-util")] + pub fn with_use_static_exponential_base(mut self, use_static_exponential_base: bool) -> Self { + self.use_static_exponential_base = use_static_exponential_base; + self + } + /// Returns the retry mode. pub fn mode(&self) -> RetryMode { self.mode @@ -383,10 +412,23 @@ impl RetryConfig { self.initial_backoff } + /// Returns the max backoff duration. + pub fn max_backoff(&self) -> Duration { + self.max_backoff + } + /// Returns true if retry is enabled with this config pub fn has_retry(&self) -> bool { self.max_attempts > 1 } + + /// Returns `true` if retry strategies should use a static exponential base instead of the + /// default random base. + /// + /// To set this value, the `test-util` feature must be enabled. + pub fn use_static_exponential_base(&self) -> bool { + self.use_static_exponential_base + } } #[cfg(test)]